optimum-rbln 0.7.4a1__py3-none-any.whl → 0.7.4a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/modeling.py +8 -1
- optimum/rbln/ops/__init__.py +3 -7
- optimum/rbln/ops/attn.py +271 -207
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +39 -20
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +22 -34
- {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/RECORD +23 -23
- {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -12,71 +12,165 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from functools import lru_cache
|
16
|
-
|
17
15
|
import torch
|
18
|
-
from
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
torch.
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
16
|
+
from torch import Tensor
|
17
|
+
|
18
|
+
|
19
|
+
@torch.library.custom_op(
|
20
|
+
"rbln_custom_ops::paged_flash_attn_decode",
|
21
|
+
mutates_args=(["kcache", "vcache"]),
|
22
|
+
)
|
23
|
+
def paged_flash_attn_decode(
|
24
|
+
q: Tensor,
|
25
|
+
k: Tensor,
|
26
|
+
v: Tensor,
|
27
|
+
mask: Tensor,
|
28
|
+
kcache: Tensor,
|
29
|
+
vcache: Tensor,
|
30
|
+
seq: Tensor,
|
31
|
+
scale: Tensor,
|
32
|
+
block_table: Tensor,
|
33
|
+
block_size: int,
|
34
|
+
partition: int,
|
35
|
+
) -> Tensor:
|
36
|
+
"""Defines the computation pattern for fused flash attention with KV cache for decoding.
|
37
|
+
|
38
|
+
Returns a tensor with the same shape as q.
|
39
|
+
"""
|
40
|
+
return torch.empty_like(q)
|
41
|
+
|
42
|
+
|
43
|
+
@paged_flash_attn_decode.register_fake
|
44
|
+
def paged_flash_attn_decode_fake(
|
45
|
+
q: Tensor,
|
46
|
+
k: Tensor,
|
47
|
+
v: Tensor,
|
48
|
+
mask: Tensor,
|
49
|
+
kcache: Tensor,
|
50
|
+
vcache: Tensor,
|
51
|
+
seq: Tensor,
|
52
|
+
scale: Tensor,
|
53
|
+
block_table: Tensor,
|
54
|
+
block_size: int,
|
55
|
+
partition: int,
|
56
|
+
) -> Tensor:
|
57
|
+
return torch.empty_like(q)
|
58
|
+
|
59
|
+
|
60
|
+
@torch.library.custom_op(
|
61
|
+
"rbln_custom_ops::paged_flash_attn_prefill",
|
62
|
+
mutates_args=(["kcache", "vcache"]),
|
63
|
+
)
|
64
|
+
def paged_flash_attn_prefill(
|
65
|
+
q: Tensor,
|
66
|
+
k: Tensor,
|
67
|
+
v: Tensor,
|
68
|
+
mask: Tensor,
|
69
|
+
kcache: Tensor,
|
70
|
+
vcache: Tensor,
|
71
|
+
seq: Tensor,
|
72
|
+
scale: Tensor,
|
73
|
+
block_table: Tensor,
|
74
|
+
block_size: int,
|
75
|
+
partition: int,
|
76
|
+
) -> Tensor:
|
77
|
+
"""Defines the computation pattern for fused flash attention with KV cache for prefill.
|
78
|
+
|
79
|
+
Returns a tensor with the same shape as q.
|
80
|
+
"""
|
81
|
+
return torch.empty_like(q)
|
82
|
+
|
83
|
+
|
84
|
+
@paged_flash_attn_prefill.register_fake
|
85
|
+
def paged_flash_attn_prefill_fake(
|
86
|
+
q: Tensor,
|
87
|
+
k: Tensor,
|
88
|
+
v: Tensor,
|
89
|
+
mask: Tensor,
|
90
|
+
kcache: Tensor,
|
91
|
+
vcache: Tensor,
|
92
|
+
seq: Tensor,
|
93
|
+
scale: Tensor,
|
94
|
+
block_table: Tensor,
|
95
|
+
block_size: int,
|
96
|
+
partition: int,
|
97
|
+
) -> Tensor:
|
98
|
+
return torch.empty_like(q)
|
99
|
+
|
100
|
+
|
101
|
+
@torch.library.custom_op(
|
102
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
103
|
+
mutates_args=(["kcache", "vcache"]),
|
104
|
+
)
|
105
|
+
def paged_flash_causal_attn_decode(
|
106
|
+
q: Tensor,
|
107
|
+
k: Tensor,
|
108
|
+
v: Tensor,
|
109
|
+
kcache: Tensor,
|
110
|
+
vcache: Tensor,
|
111
|
+
seq: Tensor,
|
112
|
+
scale: Tensor,
|
113
|
+
block_table: Tensor,
|
114
|
+
block_size: int,
|
115
|
+
partition: int,
|
116
|
+
) -> Tensor:
|
117
|
+
"""Defines the computation pattern for fused causal flash attention with KV cache for decoding.
|
118
|
+
|
119
|
+
Returns a tensor with the same shape as q.
|
120
|
+
"""
|
121
|
+
return torch.empty_like(q)
|
122
|
+
|
123
|
+
|
124
|
+
@paged_flash_causal_attn_decode.register_fake
|
125
|
+
def paged_flash_causal_attn_decode_fake(
|
126
|
+
q: Tensor,
|
127
|
+
k: Tensor,
|
128
|
+
v: Tensor,
|
129
|
+
kcache: Tensor,
|
130
|
+
vcache: Tensor,
|
131
|
+
seq: Tensor,
|
132
|
+
scale: Tensor,
|
133
|
+
block_table: Tensor,
|
134
|
+
block_size: int,
|
135
|
+
partition: int,
|
136
|
+
) -> Tensor:
|
137
|
+
return torch.empty_like(q)
|
138
|
+
|
139
|
+
|
140
|
+
@torch.library.custom_op(
|
141
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
142
|
+
mutates_args=(["kcache", "vcache"]),
|
143
|
+
)
|
144
|
+
def paged_flash_causal_attn_prefill(
|
145
|
+
q: Tensor,
|
146
|
+
k: Tensor,
|
147
|
+
v: Tensor,
|
148
|
+
kcache: Tensor,
|
149
|
+
vcache: Tensor,
|
150
|
+
seq: Tensor,
|
151
|
+
scale: Tensor,
|
152
|
+
block_table: Tensor,
|
153
|
+
block_size: int,
|
154
|
+
partition: int,
|
155
|
+
) -> Tensor:
|
156
|
+
"""Defines the computation pattern for fused causal flash attention with KV cache for prefill.
|
157
|
+
|
158
|
+
Returns a tensor with the same shape as q.
|
159
|
+
"""
|
160
|
+
return torch.empty_like(q)
|
161
|
+
|
162
|
+
|
163
|
+
@paged_flash_causal_attn_prefill.register_fake
|
164
|
+
def paged_flash_causal_attn_prefill_fake(
|
165
|
+
q: Tensor,
|
166
|
+
k: Tensor,
|
167
|
+
v: Tensor,
|
168
|
+
kcache: Tensor,
|
169
|
+
vcache: Tensor,
|
170
|
+
seq: Tensor,
|
171
|
+
scale: Tensor,
|
172
|
+
block_table: Tensor,
|
173
|
+
block_size: int,
|
174
|
+
partition: int,
|
175
|
+
) -> Tensor:
|
176
|
+
return torch.empty_like(q)
|
@@ -12,49 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from functools import lru_cache
|
16
|
-
|
17
15
|
import torch
|
18
|
-
from
|
19
|
-
|
16
|
+
from torch import Tensor
|
20
17
|
|
21
|
-
if is_torch_greater_or_equal_than_2_4:
|
22
|
-
register_fake = torch.library.register_fake
|
23
|
-
else:
|
24
|
-
register_fake = torch.library.impl_abstract
|
25
18
|
|
26
|
-
|
27
|
-
|
28
|
-
def register_rbln_custom_cache_update():
|
19
|
+
@torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
|
20
|
+
def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
|
29
21
|
# Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
|
30
22
|
# This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
|
31
23
|
# The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
|
32
|
-
torch.
|
33
|
-
|
34
|
-
# Implementation of the "rbln_cache_update" operation for the CPU.
|
35
|
-
@torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
|
36
|
-
def rbln_cache_update_cpu(cache, state, position, axis):
|
37
|
-
assert position.dim() == 0
|
38
|
-
assert axis.dim() == 0
|
39
|
-
|
40
|
-
# Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
|
41
|
-
s = position # Start index for the update, specified by the position.
|
42
|
-
e = (
|
43
|
-
position + state.shape[axis]
|
44
|
-
) # End index is determined by adding the size of the state along the given axis.
|
45
|
-
|
46
|
-
# Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
|
47
|
-
# This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
|
48
|
-
cache.slice_scatter(state, dim=axis, start=s, end=e)
|
49
|
-
|
50
|
-
# 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
|
51
|
-
return torch.empty([256])
|
52
|
-
|
53
|
-
# Register a "fake" implementation of the "rbln_cache_update" operation.
|
54
|
-
# This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
|
55
|
-
@register_fake("rbln_custom_ops::rbln_cache_update")
|
56
|
-
def rbln_cache_update_abstract(cache, state, position, axis):
|
57
|
-
# Return a tensor with the same shape as the input cache tensor.
|
58
|
-
# This is a placeholder for the abstract implementation and does not perform any actual computation.
|
59
|
-
# Like the actual implementation, the abstraction assumes in-place device-side updates.
|
60
|
-
return torch.empty([256])
|
24
|
+
return torch.empty_like(cache)
|
@@ -12,4 +12,14 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from ....ops import (
|
16
|
+
paged_attn_decode,
|
17
|
+
paged_attn_prefill,
|
18
|
+
paged_causal_attn_decode,
|
19
|
+
paged_causal_attn_prefill,
|
20
|
+
paged_flash_attn_decode,
|
21
|
+
paged_flash_attn_prefill,
|
22
|
+
paged_flash_causal_attn_decode,
|
23
|
+
paged_flash_causal_attn_prefill,
|
24
|
+
)
|
15
25
|
from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|
@@ -19,12 +19,6 @@ import torch
|
|
19
19
|
from torch import nn
|
20
20
|
from transformers import PretrainedConfig, PreTrainedModel
|
21
21
|
|
22
|
-
from ....ops import (
|
23
|
-
register_rbln_custom_paged_attention,
|
24
|
-
register_rbln_custom_paged_causal_attention,
|
25
|
-
register_rbln_custom_paged_flash_attention,
|
26
|
-
register_rbln_custom_paged_flash_causal_attention,
|
27
|
-
)
|
28
22
|
from ....utils import logging
|
29
23
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
30
24
|
|
@@ -162,16 +156,8 @@ class DecoderOnlyWrapper(nn.Module):
|
|
162
156
|
self.use_attention_mask = use_attention_mask
|
163
157
|
if self.attn_impl == "flash_attn":
|
164
158
|
self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
|
165
|
-
if self.use_attention_mask:
|
166
|
-
register_rbln_custom_paged_flash_attention()
|
167
|
-
else:
|
168
|
-
register_rbln_custom_paged_flash_causal_attention()
|
169
159
|
elif self.attn_impl == "eager":
|
170
160
|
self.kvcache_partition_len = None
|
171
|
-
if self.use_attention_mask:
|
172
|
-
register_rbln_custom_paged_attention()
|
173
|
-
else:
|
174
|
-
register_rbln_custom_paged_causal_attention()
|
175
161
|
else:
|
176
162
|
raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
|
177
163
|
|
@@ -756,55 +742,55 @@ class AttentionOp(nn.Module):
|
|
756
742
|
if self.phase == "decode":
|
757
743
|
if self.use_attention_mask:
|
758
744
|
attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
|
759
|
-
query_state,
|
760
|
-
key_state,
|
761
|
-
value_state,
|
762
|
-
attn_mask,
|
763
|
-
past_key_state.unsqueeze(2),
|
764
|
-
past_value_state.unsqueeze(2),
|
765
|
-
seq_position,
|
766
|
-
scale,
|
767
|
-
block_tables,
|
768
|
-
block_size,
|
745
|
+
q=query_state,
|
746
|
+
k=key_state,
|
747
|
+
v=value_state,
|
748
|
+
mask=attn_mask,
|
749
|
+
kcache=past_key_state.unsqueeze(2),
|
750
|
+
vcache=past_value_state.unsqueeze(2),
|
751
|
+
seq=seq_position,
|
752
|
+
scale=scale,
|
753
|
+
block_table=block_tables,
|
754
|
+
block_size=block_size,
|
769
755
|
)
|
770
756
|
else:
|
771
757
|
attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
|
772
|
-
query_state,
|
773
|
-
key_state,
|
774
|
-
value_state,
|
775
|
-
past_key_state.unsqueeze(2),
|
776
|
-
past_value_state.unsqueeze(2),
|
777
|
-
seq_position,
|
778
|
-
scale,
|
779
|
-
block_tables,
|
780
|
-
block_size,
|
758
|
+
q=query_state,
|
759
|
+
k=key_state,
|
760
|
+
v=value_state,
|
761
|
+
kcache=past_key_state.unsqueeze(2),
|
762
|
+
vcache=past_value_state.unsqueeze(2),
|
763
|
+
seq=seq_position,
|
764
|
+
scale=scale,
|
765
|
+
block_table=block_tables,
|
766
|
+
block_size=block_size,
|
781
767
|
)
|
782
768
|
|
783
769
|
else:
|
784
770
|
if self.use_attention_mask:
|
785
771
|
attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
|
786
|
-
query_state,
|
787
|
-
key_state,
|
788
|
-
value_state,
|
789
|
-
attn_mask,
|
790
|
-
past_key_state.unsqueeze(2),
|
791
|
-
past_value_state.unsqueeze(2),
|
792
|
-
seq_position,
|
793
|
-
scale,
|
794
|
-
block_tables,
|
795
|
-
block_size,
|
772
|
+
q=query_state,
|
773
|
+
k=key_state,
|
774
|
+
v=value_state,
|
775
|
+
mask=attn_mask,
|
776
|
+
kcache=past_key_state.unsqueeze(2),
|
777
|
+
vcache=past_value_state.unsqueeze(2),
|
778
|
+
seq=seq_position,
|
779
|
+
scale=scale,
|
780
|
+
block_table=block_tables,
|
781
|
+
block_size=block_size,
|
796
782
|
)
|
797
783
|
else:
|
798
784
|
attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
|
799
|
-
query_state,
|
800
|
-
key_state,
|
801
|
-
value_state,
|
802
|
-
past_key_state.unsqueeze(2),
|
803
|
-
past_value_state.unsqueeze(2),
|
804
|
-
seq_position,
|
805
|
-
scale,
|
806
|
-
block_tables,
|
807
|
-
block_size,
|
785
|
+
q=query_state,
|
786
|
+
k=key_state,
|
787
|
+
v=value_state,
|
788
|
+
kcache=past_key_state.unsqueeze(2),
|
789
|
+
vcache=past_value_state.unsqueeze(2),
|
790
|
+
seq=seq_position,
|
791
|
+
scale=scale,
|
792
|
+
block_table=block_tables,
|
793
|
+
block_size=block_size,
|
808
794
|
)
|
809
795
|
|
810
796
|
attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
|
@@ -1015,58 +1001,58 @@ class FlashAttentionOp(AttentionOp):
|
|
1015
1001
|
if self.phase == "decode":
|
1016
1002
|
if self.use_attention_mask:
|
1017
1003
|
attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
|
1018
|
-
query_state,
|
1019
|
-
key_state,
|
1020
|
-
value_state,
|
1021
|
-
attn_mask,
|
1022
|
-
past_key_state.unsqueeze(2),
|
1023
|
-
past_value_state.unsqueeze(2),
|
1024
|
-
seq_position,
|
1025
|
-
scale,
|
1026
|
-
block_tables,
|
1027
|
-
kvcache_block_size,
|
1028
|
-
self.kvcache_partition_size,
|
1004
|
+
q=query_state,
|
1005
|
+
k=key_state,
|
1006
|
+
v=value_state,
|
1007
|
+
mask=attn_mask,
|
1008
|
+
kcache=past_key_state.unsqueeze(2),
|
1009
|
+
vcache=past_value_state.unsqueeze(2),
|
1010
|
+
seq=seq_position,
|
1011
|
+
scale=scale,
|
1012
|
+
block_table=block_tables,
|
1013
|
+
block_size=kvcache_block_size,
|
1014
|
+
partition=self.kvcache_partition_size,
|
1029
1015
|
)
|
1030
1016
|
else:
|
1031
1017
|
attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
|
1032
|
-
query_state,
|
1033
|
-
key_state,
|
1034
|
-
value_state,
|
1035
|
-
past_key_state.unsqueeze(2),
|
1036
|
-
past_value_state.unsqueeze(2),
|
1037
|
-
seq_position,
|
1038
|
-
scale,
|
1039
|
-
block_tables,
|
1040
|
-
kvcache_block_size,
|
1041
|
-
self.kvcache_partition_size,
|
1018
|
+
q=query_state,
|
1019
|
+
k=key_state,
|
1020
|
+
v=value_state,
|
1021
|
+
kcache=past_key_state.unsqueeze(2),
|
1022
|
+
vcache=past_value_state.unsqueeze(2),
|
1023
|
+
seq=seq_position,
|
1024
|
+
scale=scale,
|
1025
|
+
block_table=block_tables,
|
1026
|
+
block_size=kvcache_block_size,
|
1027
|
+
partition=self.kvcache_partition_size,
|
1042
1028
|
)
|
1043
1029
|
else:
|
1044
1030
|
if self.use_attention_mask:
|
1045
1031
|
attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
|
1046
|
-
query_state,
|
1047
|
-
key_state,
|
1048
|
-
value_state,
|
1049
|
-
attn_mask,
|
1050
|
-
past_key_state.unsqueeze(2),
|
1051
|
-
past_value_state.unsqueeze(2),
|
1052
|
-
seq_position,
|
1053
|
-
scale,
|
1054
|
-
block_tables,
|
1055
|
-
kvcache_block_size,
|
1056
|
-
self.kvcache_partition_size,
|
1032
|
+
q=query_state,
|
1033
|
+
k=key_state,
|
1034
|
+
v=value_state,
|
1035
|
+
mask=attn_mask,
|
1036
|
+
kcache=past_key_state.unsqueeze(2),
|
1037
|
+
vcache=past_value_state.unsqueeze(2),
|
1038
|
+
seq=seq_position,
|
1039
|
+
scale=scale,
|
1040
|
+
block_table=block_tables,
|
1041
|
+
block_size=kvcache_block_size,
|
1042
|
+
partition=self.kvcache_partition_size,
|
1057
1043
|
)
|
1058
1044
|
else:
|
1059
1045
|
attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
|
1060
|
-
query_state,
|
1061
|
-
key_state,
|
1062
|
-
value_state,
|
1063
|
-
past_key_state.unsqueeze(2),
|
1064
|
-
past_value_state.unsqueeze(2),
|
1065
|
-
seq_position,
|
1066
|
-
scale,
|
1067
|
-
block_tables,
|
1068
|
-
kvcache_block_size,
|
1069
|
-
self.kvcache_partition_size,
|
1046
|
+
q=query_state,
|
1047
|
+
k=key_state,
|
1048
|
+
v=value_state,
|
1049
|
+
kcache=past_key_state.unsqueeze(2),
|
1050
|
+
vcache=past_value_state.unsqueeze(2),
|
1051
|
+
seq=seq_position,
|
1052
|
+
scale=scale,
|
1053
|
+
block_table=block_tables,
|
1054
|
+
block_size=kvcache_block_size,
|
1055
|
+
partition=self.kvcache_partition_size,
|
1070
1056
|
)
|
1071
1057
|
|
1072
1058
|
# reshape for removing repeat_kv
|
@@ -578,11 +578,41 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
578
578
|
nbits_per_param: int,
|
579
579
|
n_model_params: int,
|
580
580
|
) -> int:
|
581
|
+
"""
|
582
|
+
We are finding max_n_blocks(x) that satisfies the following equation:
|
583
|
+
|
584
|
+
available_dram - kernel_size - buffer
|
585
|
+
- num_layers * 2 * tensor_parallel_size
|
586
|
+
* align_2MB(
|
587
|
+
x
|
588
|
+
* block_size
|
589
|
+
* align_64(head_dim)
|
590
|
+
* math.ceil(num_key_value_heads / tensor_parallel_size)
|
591
|
+
* 2
|
592
|
+
) > 0
|
593
|
+
|
594
|
+
This inequality can be rewritten as follows:
|
595
|
+
|
596
|
+
a - c * align_2MB(b * x) > 0
|
597
|
+
where
|
598
|
+
a = available_dram - kernel_size - buffer
|
599
|
+
b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
|
600
|
+
c = num_layers * 2 * tensor_parallel_size
|
601
|
+
|
602
|
+
We can rewrite the inequality as follows:
|
603
|
+
k > align_2MB(b*x)
|
604
|
+
where
|
605
|
+
k = a / c
|
606
|
+
|
607
|
+
After that, we can derive the following equation:
|
608
|
+
x = floor(2**21 / b * floor((k - 1) / 2**21))
|
609
|
+
"""
|
610
|
+
|
581
611
|
def align(x: int, nbytes: int) -> int:
|
582
612
|
return int(math.ceil(x / nbytes) * nbytes)
|
583
613
|
|
584
614
|
def align_2MB(x: int) -> int:
|
585
|
-
return align(x, 2
|
615
|
+
return align(x, 2**21)
|
586
616
|
|
587
617
|
num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
|
588
618
|
num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
|
@@ -612,27 +642,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
612
642
|
available_dram -= kernel_size
|
613
643
|
|
614
644
|
# TODO: Accurate buffer estimation
|
615
|
-
|
616
|
-
|
617
|
-
buffer /= 4
|
618
|
-
|
645
|
+
buffer_per_core = 2**29 # 500MB per npu
|
646
|
+
buffer = buffer_per_core * tensor_parallel_size
|
619
647
|
available_dram -= buffer
|
620
648
|
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
* head_dim
|
626
|
-
* math.ceil(num_key_value_heads / tensor_parallel_size) # Shard
|
627
|
-
* 2 # (fp16)
|
628
|
-
)
|
629
|
-
* num_layers
|
630
|
-
* 2 # (k, v)
|
631
|
-
* tensor_parallel_size
|
632
|
-
)
|
633
|
-
n_blocks = available_dram // nbytes_per_block
|
649
|
+
b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
|
650
|
+
c = num_layers * 2 * tensor_parallel_size
|
651
|
+
k = available_dram / c
|
652
|
+
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
|
634
653
|
|
635
|
-
return
|
654
|
+
return max_n_blocks
|
636
655
|
|
637
656
|
@classmethod
|
638
657
|
def _get_rbln_config(
|
@@ -689,7 +708,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
689
708
|
|
690
709
|
rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
|
691
710
|
if rbln_attn_impl == "flash_attn":
|
692
|
-
max_num_blocks
|
711
|
+
max_num_blocks = cls.get_maximum_num_blocks(
|
693
712
|
config=model_config,
|
694
713
|
tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
|
695
714
|
kvcache_block_size=rbln_kvcache_block_size,
|
@@ -247,19 +247,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
247
247
|
enc_input_info = [
|
248
248
|
("input_ids", [1, rbln_enc_max_seq_len], "int64"),
|
249
249
|
("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
|
250
|
-
(
|
251
|
-
"cross_key_value_states",
|
252
|
-
[
|
253
|
-
n_layer * 2,
|
254
|
-
rbln_batch_size,
|
255
|
-
n_head,
|
256
|
-
rbln_enc_max_seq_len,
|
257
|
-
d_kv,
|
258
|
-
],
|
259
|
-
"float32",
|
260
|
-
),
|
261
250
|
("block_tables", [1], "int16"),
|
262
251
|
]
|
252
|
+
enc_input_info.extend(
|
253
|
+
[
|
254
|
+
(
|
255
|
+
f"cross_key_value_states_{i}",
|
256
|
+
[
|
257
|
+
rbln_batch_size,
|
258
|
+
n_head,
|
259
|
+
rbln_enc_max_seq_len,
|
260
|
+
d_kv,
|
261
|
+
],
|
262
|
+
"float32",
|
263
|
+
)
|
264
|
+
for i in range(n_layer * 2)
|
265
|
+
]
|
266
|
+
)
|
263
267
|
|
264
268
|
dec_input_info = [
|
265
269
|
("input_ids", [rbln_batch_size, 1], "int64"),
|
@@ -274,9 +278,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
274
278
|
dec_input_info.extend(
|
275
279
|
[
|
276
280
|
(
|
277
|
-
"
|
281
|
+
f"cross_key_value_states_{i}",
|
278
282
|
[
|
279
|
-
n_layer * 2,
|
280
283
|
rbln_batch_size,
|
281
284
|
n_head,
|
282
285
|
rbln_enc_max_seq_len,
|
@@ -284,6 +287,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
284
287
|
],
|
285
288
|
"float32",
|
286
289
|
)
|
290
|
+
for i in range(n_layer * 2)
|
287
291
|
]
|
288
292
|
)
|
289
293
|
dec_input_info.extend(
|