optimum_rbln-0.7.4a2-py3-none-any.whl → optimum_rbln-0.7.4a4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/modeling.py +8 -1
- optimum/rbln/modeling_base.py +0 -5
- optimum/rbln/ops/__init__.py +3 -7
- optimum/rbln/ops/attn.py +271 -207
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -37
- optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +20 -32
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/RECORD +24 -24
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -12,71 +12,165 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import lru_cache
-
 import torch
-from …
[… old lines 19-82 were also deleted; their content is not rendered in this diff view …]
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused flash attention with KV cache for decoding.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode.register_fake
+def paged_flash_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused flash attention with KV cache for prefill.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill.register_fake
+def paged_flash_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode.register_fake
+def paged_flash_causal_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill.register_fake
+def paged_flash_causal_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
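The new module defines each kernel as a `torch.library.custom_op` with a matching `register_fake` shape stub, instead of registering ops lazily through helper functions. A minimal sketch of the same pattern in isolation (assumes PyTorch >= 2.4; the `toy_ns::toy_attn` name and the shapes are illustrative, not part of optimum-rbln):

import torch
from torch import Tensor


@torch.library.custom_op("toy_ns::toy_attn", mutates_args=())
def toy_attn(q: Tensor, scale: float) -> Tensor:
    # Eager placeholder body; a backend compiler would swap in the fused kernel.
    return q * scale


@toy_attn.register_fake
def toy_attn_fake(q: Tensor, scale: float) -> Tensor:
    # Shape/dtype-only stub used while tracing (FakeTensor / torch.compile / torch.export).
    return torch.empty_like(q)


q = torch.randn(1, 8, 1, 64)
out = torch.ops.toy_ns.toy_attn(q, 0.125)  # addressable under torch.ops, like the paged_* ops above
assert out.shape == q.shape

Because the eager body is only a placeholder, the op traces cleanly while the optimized implementation is supplied by the RBLN compiler at build time.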
optimum/rbln/ops/kv_cache_update.py
CHANGED
@@ -12,49 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import lru_cache
-
 import torch
-from …
-
+from torch import Tensor

-if is_torch_greater_or_equal_than_2_4:
-    register_fake = torch.library.register_fake
-else:
-    register_fake = torch.library.impl_abstract

-
-
-def register_rbln_custom_cache_update():
+@torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
+def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
     # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
     # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
     # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
-    torch.…
-
-    # Implementation of the "rbln_cache_update" operation for the CPU.
-    @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
-    def rbln_cache_update_cpu(cache, state, position, axis):
-        assert position.dim() == 0
-        assert axis.dim() == 0
-
-        # Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
-        s = position  # Start index for the update, specified by the position.
-        e = (
-            position + state.shape[axis]
-        )  # End index is determined by adding the size of the state along the given axis.
-
-        # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
-        # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-        cache.slice_scatter(state, dim=axis, start=s, end=e)
-
-        # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
-        return torch.empty([256])
-
-    # Register a "fake" implementation of the "rbln_cache_update" operation.
-    # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
-    @register_fake("rbln_custom_ops::rbln_cache_update")
-    def rbln_cache_update_abstract(cache, state, position, axis):
-        # Return a tensor with the same shape as the input cache tensor.
-        # This is a placeholder for the abstract implementation and does not perform any actual computation.
-        # Like the actual implementation, the abstraction assumes in-place device-side updates.
-        return torch.empty([256])
+    return torch.empty_like(cache)
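The cache update is now an ordinary functional custom op, `rbln_cache_update(cache, state, position, axis)`, rather than something registered by calling a helper first. A hedged usage sketch, assuming that importing `optimum.rbln.ops` registers the op; all shapes are illustrative, and in eager mode the op only returns a placeholder with the cache's shape (the real device-side update happens in the compiled graph):

import torch

import optimum.rbln.ops  # noqa: F401  (assumption: importing this module registers the rbln_custom_ops ops)

cache = torch.zeros(4, 8, 512, 64)          # e.g. (batch, heads, seq, head_dim); illustrative only
state = torch.randn(1, 8, 512, 64)          # the slice to write into the cache
position = torch.tensor(2)                  # scalar start index along `axis` (here: batch entry 2)
axis = torch.tensor(0, dtype=torch.int16)   # batch axis, mirroring Seq2SeqEncoderWrapper later in this diff

out = torch.ops.rbln_custom_ops.rbln_cache_update(cache, state, position, axis)
assert out.shape == cache.shape             # eager execution only returns a placeholder of the cache's shape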
optimum/rbln/transformers/models/decoderonly/__init__.py
CHANGED
@@ -12,4 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from ....ops import (
+    paged_attn_decode,
+    paged_attn_prefill,
+    paged_causal_attn_decode,
+    paged_causal_attn_prefill,
+    paged_flash_attn_decode,
+    paged_flash_attn_prefill,
+    paged_flash_causal_attn_decode,
+    paged_flash_causal_attn_prefill,
+)
 from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -19,12 +19,6 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

-from ....ops import (
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-    register_rbln_custom_paged_flash_attention,
-    register_rbln_custom_paged_flash_causal_attention,
-)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS

@@ -162,16 +156,8 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-            if self.use_attention_mask:
-                register_rbln_custom_paged_flash_attention()
-            else:
-                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-            if self.use_attention_mask:
-                register_rbln_custom_paged_attention()
-            else:
-                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

@@ -756,55 +742,55 @@ class AttentionOp(nn.Module):
         if self.phase == "decode":
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )

         else:
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -1015,58 +1001,58 @@ class FlashAttentionOp(AttentionOp):
         if self.phase == "decode":
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
         else:
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )

         # reshape for removing repeat_kv
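Because the ops now carry named parameters, the call sites above pass keyword arguments, so each tensor binds to an explicit role (q/k/v, kcache/vcache, block_table, partition) instead of a position. A sketch of the same call outside the model wrapper (assumes `optimum.rbln.ops` is imported; every shape and value below is a placeholder, and eager execution only exercises the placeholder body):

import torch

import optimum.rbln.ops  # noqa: F401  (assumption: importing registers the ops)

q = torch.randn(1, 8, 1, 1, 64)          # placeholder shapes, not the real RBLN layout
k = torch.randn(1, 8, 1, 1, 64)
v = torch.randn(1, 8, 1, 1, 64)
kcache = torch.zeros(1, 8, 1, 1024, 64)
vcache = torch.zeros(1, 8, 1, 1024, 64)

out = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
    q=q,
    k=k,
    v=v,
    kcache=kcache,
    vcache=vcache,
    seq=torch.tensor([10]),
    scale=torch.tensor(0.125),
    block_table=torch.zeros(1, 8, dtype=torch.int16),
    block_size=128,
    partition=512,
)
assert out.shape == q.shape  # the placeholder body returns empty_like(q)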
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
CHANGED
@@ -247,19 +247,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         enc_input_info = [
             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-            (
-                "cross_key_value_states",
-                [
-                    n_layer * 2,
-                    rbln_batch_size,
-                    n_head,
-                    rbln_enc_max_seq_len,
-                    d_kv,
-                ],
-                "float32",
-            ),
             ("block_tables", [1], "int16"),
         ]
+        enc_input_info.extend(
+            [
+                (
+                    f"cross_key_value_states_{i}",
+                    [
+                        rbln_batch_size,
+                        n_head,
+                        rbln_enc_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+                for i in range(n_layer * 2)
+            ]
+        )

         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
@@ -274,9 +278,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         dec_input_info.extend(
             [
                 (
-                    "…
+                    f"cross_key_value_states_{i}",
                     [
-                        n_layer * 2,
                         rbln_batch_size,
                         n_head,
                         rbln_enc_max_seq_len,
@@ -284,6 +287,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                     ],
                     "float32",
                 )
+                for i in range(n_layer * 2)
             ]
         )
         dec_input_info.extend(
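The compile-time input specification no longer declares one stacked `cross_key_value_states` tensor of shape `[n_layer * 2, batch, heads, seq, d_kv]`; each layer's key and value cache becomes its own named input. A small sketch of the resulting encoder spec with toy sizes (the numbers are placeholders, mirroring the list comprehension added above):

n_layer, rbln_batch_size, n_head, rbln_enc_max_seq_len, d_kv = 2, 1, 8, 512, 64

enc_input_info = [
    ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
    ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
    ("block_tables", [1], "int16"),
]
enc_input_info.extend(
    (f"cross_key_value_states_{i}", [rbln_batch_size, n_head, rbln_enc_max_seq_len, d_kv], "float32")
    for i in range(n_layer * 2)
)

# One entry per layer per K/V, instead of a single stacked "cross_key_value_states" tensor:
print(enc_input_info[3])  # ('cross_key_value_states_0', [1, 8, 512, 64], 'float32')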
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
CHANGED
@@ -18,12 +18,6 @@ import torch
 from torch import nn
 from transformers.utils import logging

-from ....ops import (
-    register_rbln_custom_cache_update,
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-)
-

 logger = logging.get_logger(__name__)

@@ -59,7 +53,6 @@ class Seq2SeqEncoderWrapper(nn.Module):

     def __init__(self, model: nn.Module, enc_max_seq_len: int):
         super().__init__()
-        register_rbln_custom_cache_update()
         self.config = model.config
         self.encoder = model.get_encoder()
         self.encoder_max_length = enc_max_seq_len
@@ -90,8 +83,8 @@ class Seq2SeqEncoderWrapper(nn.Module):
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
-        cross_key_values: torch.Tensor,
         b_idx: torch.Tensor,
+        *cross_key_values: Tuple[torch.Tensor],
     ) -> Tuple[torch.Tensor]:
         # 1. get encoder last_hidden_states
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -110,13 +103,15 @@ class Seq2SeqEncoderWrapper(nn.Module):
             cross_kv.append(past_k)
             cross_kv.append(past_v)

-        cross_kv = torch.stack(cross_kv, dim=0)
-
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
-        batch_axis = torch.tensor(…
-        …
+        batch_axis = torch.tensor(0, dtype=torch.int16)
+        cross_key_values = list(cross_key_values)
+        for i in range(self.n_layer * 2):
+            cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
+                cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
+            )

-        return …
+        return cross_key_values


 class Seq2SeqDecoderWrapper(nn.Module):
@@ -146,11 +141,6 @@ class Seq2SeqDecoderWrapper(nn.Module):
         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
         by subclasses to modify or add custom attributes as necessary.
         """
-        if self.use_attention_mask:
-            register_rbln_custom_paged_attention()
-        else:
-            register_rbln_custom_paged_causal_attention()
-
         self.num_layers = self.config.decoder_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)

@@ -176,16 +166,17 @@ class Seq2SeqDecoderWrapper(nn.Module):
                 encoder_attention_mask,
                 cache_position,
                 block_tables,
-                …
-                *self_kv_cache,
+                *kv_cache,
             ) = args

         else:
             attention_mask = None
-            (input_ids, encoder_attention_mask, cache_position, block_tables, …
+            (input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args

         self_past_key_values = ()
         cross_past_key_values = ()
+        self_kv_cache = kv_cache[self.num_layers * 2 :]
+        cross_kv_cache = kv_cache[: self.num_layers * 2]
         for i in range(0, self.num_layers * 2, 2):
             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
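The decoder wrapper now receives one flat `*kv_cache` argument and slices it by convention: the first `num_layers * 2` tensors are the cross-attention caches, the remainder are the self-attention caches, and each consecutive pair forms a layer's (key, value). A toy sketch of that slicing, with string labels standing in for tensors:

num_layers = 2
# Flat argument list: cross-attention K/V for every layer first, then self-attention K/V.
kv_cache = [f"cross_{i}" for i in range(num_layers * 2)] + [f"self_{i}" for i in range(num_layers * 2)]

cross_kv_cache = kv_cache[: num_layers * 2]
self_kv_cache = kv_cache[num_layers * 2 :]

self_past_key_values = tuple(
    (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
cross_past_key_values = tuple(
    (cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
print(self_past_key_values)   # (('self_0', 'self_1'), ('self_2', 'self_3'))
print(cross_past_key_values)  # (('cross_0', 'cross_1'), ('cross_2', 'cross_3'))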