optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +4 -4
- optimum/rbln/ops/attn.py +44 -84
- optimum/rbln/ops/flash_attn.py +25 -25
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +79 -51
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +157 -34
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +7 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +7 -2
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +3 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -3
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +44 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +50 -19
- optimum/rbln/transformers/models/t5/modeling_t5.py +211 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +69 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +19 -24
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/RECORD +22 -22
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
18
18
|
version_tuple: VERSION_TUPLE
|
19
19
|
|
20
|
-
__version__ = version = '0.7.
|
21
|
-
__version_tuple__ = version_tuple = (0, 7, 3)
|
20
|
+
__version__ = version = '0.7.3a3'
|
21
|
+
__version_tuple__ = version_tuple = (0, 7, 3, 'a3')
|
optimum/rbln/ops/__init__.py
CHANGED
@@ -13,9 +13,9 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from .attn import (
|
16
|
-
|
17
|
-
|
18
|
-
|
16
|
+
register_rbln_custom_add_softmax_attention,
|
17
|
+
register_rbln_custom_paged_attention,
|
18
|
+
register_rbln_custom_paged_causal_attention,
|
19
19
|
)
|
20
|
-
from .flash_attn import
|
20
|
+
from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
|
21
21
|
from .kv_cache_update import register_rbln_custom_cache_update
|
optimum/rbln/ops/attn.py
CHANGED
@@ -25,14 +25,14 @@ else:
|
|
25
25
|
|
26
26
|
|
27
27
|
@lru_cache
|
28
|
-
def
|
28
|
+
def register_rbln_custom_paged_attention():
|
29
29
|
torch.library.define(
|
30
|
-
"rbln_custom_ops::
|
31
|
-
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
|
30
|
+
"rbln_custom_ops::paged_attn_decode",
|
31
|
+
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
|
32
32
|
)
|
33
33
|
|
34
|
-
@torch.library.impl("rbln_custom_ops::
|
35
|
-
def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
|
34
|
+
@torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
|
35
|
+
def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
|
36
36
|
"""Defines the computation pattern for fused attention with KV cache updates.
|
37
37
|
|
38
38
|
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
@@ -51,8 +51,10 @@ def register_rbln_custom_masked_attention():
|
|
51
51
|
- mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
|
52
52
|
- kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
|
53
53
|
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
54
|
-
- seq: [1] - Current sequence position
|
54
|
+
- seq: [1, 1] - Current sequence position
|
55
55
|
- scale: [] - Attention scale factor
|
56
|
+
- block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
|
57
|
+
- block_size: [] - Number of tokens per block
|
56
58
|
|
57
59
|
Returns:
|
58
60
|
Tuple[Tensor, Tensor, Tensor]:
|
@@ -66,8 +68,8 @@ def register_rbln_custom_masked_attention():
|
|
66
68
|
torch.empty(*vcache.shape, device=vcache.device),
|
67
69
|
)
|
68
70
|
|
69
|
-
@register_fake("rbln_custom_ops::
|
70
|
-
def attn_decode_abstract(q, k, v, m, kcache, vcache, seq,
|
71
|
+
@register_fake("rbln_custom_ops::paged_attn_decode")
|
72
|
+
def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
|
71
73
|
return (
|
72
74
|
q,
|
73
75
|
torch.empty(*kcache.shape, device=kcache.device),
|
@@ -75,12 +77,12 @@ def register_rbln_custom_masked_attention():
|
|
75
77
|
)
|
76
78
|
|
77
79
|
torch.library.define(
|
78
|
-
"rbln_custom_ops::
|
79
|
-
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
|
80
|
+
"rbln_custom_ops::paged_attn_prefill",
|
81
|
+
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
|
80
82
|
)
|
81
83
|
|
82
|
-
@torch.library.impl("rbln_custom_ops::
|
83
|
-
def attn_prefill_cpu(q, k, v, mask, kcache, vcache,
|
84
|
+
@torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
|
85
|
+
def attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
|
84
86
|
"""Defines the computation pattern for prefill phase attention with KV cache updates.
|
85
87
|
|
86
88
|
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
@@ -97,9 +99,10 @@ def register_rbln_custom_masked_attention():
|
|
97
99
|
- mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
|
98
100
|
- kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
|
99
101
|
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
100
|
-
-
|
101
|
-
- seq: [1] - Starting sequence position
|
102
|
+
- seq: [1, 1] - Starting sequence position
|
102
103
|
- scale: [] - Attention scale factor
|
104
|
+
- block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
|
105
|
+
- block_size: [] - Number of tokens per block
|
103
106
|
|
104
107
|
Returns:
|
105
108
|
Tuple[Tensor, Tensor, Tensor]:
|
@@ -109,20 +112,20 @@ def register_rbln_custom_masked_attention():
|
|
109
112
|
"""
|
110
113
|
return q, kcache, vcache
|
111
114
|
|
112
|
-
@register_fake("rbln_custom_ops::
|
113
|
-
def attn_prefill_abstract(q, k, v, m, kcache, vcache,
|
115
|
+
@register_fake("rbln_custom_ops::paged_attn_prefill")
|
116
|
+
def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
|
114
117
|
return q, kcache, vcache
|
115
118
|
|
116
119
|
|
117
120
|
@lru_cache
|
118
|
-
def
|
121
|
+
def register_rbln_custom_paged_causal_attention():
|
119
122
|
torch.library.define(
|
120
|
-
"rbln_custom_ops::
|
121
|
-
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
|
123
|
+
"rbln_custom_ops::paged_causal_attn_decode",
|
124
|
+
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
|
122
125
|
)
|
123
126
|
|
124
|
-
@torch.library.impl("rbln_custom_ops::
|
125
|
-
def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale):
|
127
|
+
@torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
|
128
|
+
def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
|
126
129
|
"""Defines the computation pattern for fused attention with KV cache updates.
|
127
130
|
|
128
131
|
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
@@ -140,8 +143,10 @@ def register_rbln_custom_causal_masked_attention():
|
|
140
143
|
- v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
|
141
144
|
- kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
|
142
145
|
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
143
|
-
- seq: [1] -
|
146
|
+
- seq: [1, 1] - Starting sequence position
|
144
147
|
- scale: [] - Attention scale factor
|
148
|
+
- block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
|
149
|
+
- block_size: [] - Number of tokens per block
|
145
150
|
|
146
151
|
Returns:
|
147
152
|
Tuple[Tensor, Tensor, Tensor]:
|
@@ -155,8 +160,8 @@ def register_rbln_custom_causal_masked_attention():
|
|
155
160
|
torch.empty(*vcache.shape, device=vcache.device),
|
156
161
|
)
|
157
162
|
|
158
|
-
@register_fake("rbln_custom_ops::
|
159
|
-
def attn_decode_abstract(q, k, v, kcache, vcache, seq,
|
163
|
+
@register_fake("rbln_custom_ops::paged_causal_attn_decode")
|
164
|
+
def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
|
160
165
|
return (
|
161
166
|
q,
|
162
167
|
torch.empty(*kcache.shape, device=kcache.device),
|
@@ -164,12 +169,12 @@ def register_rbln_custom_causal_masked_attention():
|
|
164
169
|
)
|
165
170
|
|
166
171
|
torch.library.define(
|
167
|
-
"rbln_custom_ops::
|
168
|
-
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
|
172
|
+
"rbln_custom_ops::paged_causal_attn_prefill",
|
173
|
+
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
|
169
174
|
)
|
170
175
|
|
171
|
-
@torch.library.impl("rbln_custom_ops::
|
172
|
-
def attn_prefill_cpu(q, k, v, kcache, vcache,
|
176
|
+
@torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
|
177
|
+
def attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
|
173
178
|
"""Defines the computation pattern for prefill phase attention with KV cache updates.
|
174
179
|
|
175
180
|
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
@@ -186,8 +191,10 @@ def register_rbln_custom_causal_masked_attention():
|
|
186
191
|
- kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
|
187
192
|
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
188
193
|
- batch: [1] - Batch index for cache access
|
189
|
-
- seq: [1] - Starting sequence position
|
194
|
+
- seq: [1, 1] - Starting sequence position
|
190
195
|
- scale: [] - Attention scale factor
|
196
|
+
- block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
|
197
|
+
- block_size: [] - Number of tokens per block
|
191
198
|
|
192
199
|
Returns:
|
193
200
|
Tuple[Tensor, Tensor, Tensor]:
|
@@ -197,20 +204,20 @@ def register_rbln_custom_causal_masked_attention():
|
|
197
204
|
"""
|
198
205
|
return q, kcache, vcache
|
199
206
|
|
200
|
-
@register_fake("rbln_custom_ops::
|
201
|
-
def attn_prefill_abstract(q, k, v, kcache, vcache,
|
207
|
+
@register_fake("rbln_custom_ops::paged_causal_attn_prefill")
|
208
|
+
def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
|
202
209
|
return q, kcache, vcache
|
203
210
|
|
204
211
|
|
205
212
|
@lru_cache
|
206
|
-
def
|
213
|
+
def register_rbln_custom_add_softmax_attention():
|
207
214
|
torch.library.define(
|
208
|
-
"rbln_custom_ops::
|
215
|
+
"rbln_custom_ops::add_softmax_attn_decode",
|
209
216
|
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
|
210
217
|
)
|
211
218
|
|
212
|
-
@torch.library.impl("rbln_custom_ops::
|
213
|
-
def
|
219
|
+
@torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
|
220
|
+
def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
|
214
221
|
"""Defines the computation pattern for fused attention with KV cache updates.
|
215
222
|
|
216
223
|
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
@@ -244,57 +251,10 @@ def register_rbln_custom_attention_add_softmax():
|
|
244
251
|
torch.empty(*vcache.shape, device=vcache.device),
|
245
252
|
)
|
246
253
|
|
247
|
-
@register_fake("rbln_custom_ops::
|
248
|
-
def
|
254
|
+
@register_fake("rbln_custom_ops::add_softmax_attn_decode")
|
255
|
+
def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
|
249
256
|
return (
|
250
257
|
q,
|
251
258
|
torch.empty(*kcache.shape, device=kcache.device),
|
252
259
|
torch.empty(*vcache.shape, device=vcache.device),
|
253
260
|
)
|
254
|
-
|
255
|
-
torch.library.define(
|
256
|
-
"rbln_custom_ops::attn_prefill_add_softmax",
|
257
|
-
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
|
258
|
-
)
|
259
|
-
|
260
|
-
@torch.library.impl("rbln_custom_ops::attn_prefill_add_softmax", "cpu")
|
261
|
-
def attn_prefill_add_softmax_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
|
262
|
-
"""Defines the computation pattern for prefill phase attention with KV cache updates.
|
263
|
-
|
264
|
-
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
265
|
-
a single optimized NPU operation. It is NOT meant for CPU execution.
|
266
|
-
|
267
|
-
Key differences from decode pattern:
|
268
|
-
- Handles prefill phase with multiple input tokens
|
269
|
-
- Takes explicit batch index for continuous batching
|
270
|
-
|
271
|
-
Expected tensor shapes:
|
272
|
-
- q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
|
273
|
-
- k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
|
274
|
-
- v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
|
275
|
-
- mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
|
276
|
-
- kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
|
277
|
-
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
278
|
-
- batch: [1] - Batch index for cache access
|
279
|
-
- seq: [1] - Starting sequence position
|
280
|
-
- scale: [] - Attention scale factor
|
281
|
-
|
282
|
-
Returns:
|
283
|
-
Tuple[Tensor, Tensor, Tensor]:
|
284
|
-
- attn_output: [batch=1, n_heads, seq_len, 1, head_dim] - Attention output
|
285
|
-
- empty_kcache: Same shape as input kcache - Placeholder for compiler
|
286
|
-
- empty_vcache: Same shape as input vcache - Placeholder for compiler
|
287
|
-
"""
|
288
|
-
return (
|
289
|
-
q,
|
290
|
-
torch.empty(1, *kcache.shape[1:], device=kcache.device),
|
291
|
-
torch.empty(1, *vcache.shape[1:], device=vcache.device),
|
292
|
-
)
|
293
|
-
|
294
|
-
@register_fake("rbln_custom_ops::attn_prefill_add_softmax")
|
295
|
-
def attn_prefill_add_softmax_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
|
296
|
-
return (
|
297
|
-
q,
|
298
|
-
torch.empty(1, *kcache.shape[1:], device=kcache.device),
|
299
|
-
torch.empty(1, *vcache.shape[1:], device=vcache.device),
|
300
|
-
)
|
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -25,22 +25,22 @@ else:
|
|
25
25
|
|
26
26
|
|
27
27
|
@lru_cache
|
28
|
-
def
|
28
|
+
def register_rbln_custom_paged_flash_attention():
|
29
29
|
torch.library.define(
|
30
|
-
"rbln_custom_ops::
|
31
|
-
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int
|
30
|
+
"rbln_custom_ops::paged_flash_attn_decode",
|
31
|
+
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
|
32
32
|
)
|
33
33
|
|
34
|
-
@torch.library.impl("rbln_custom_ops::
|
35
|
-
def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
|
34
|
+
@torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
|
35
|
+
def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
|
36
36
|
return (
|
37
37
|
q,
|
38
38
|
torch.empty(*kcache.shape, device=kcache.device),
|
39
39
|
torch.empty(*vcache.shape, device=vcache.device),
|
40
40
|
)
|
41
41
|
|
42
|
-
@register_fake("rbln_custom_ops::
|
43
|
-
def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
|
42
|
+
@register_fake("rbln_custom_ops::paged_flash_attn_decode")
|
43
|
+
def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
|
44
44
|
return (
|
45
45
|
q,
|
46
46
|
torch.empty(*kcache.shape, device=kcache.device),
|
@@ -48,36 +48,36 @@ def register_rbln_custom_flash_masked_attention():
|
|
48
48
|
)
|
49
49
|
|
50
50
|
torch.library.define(
|
51
|
-
"rbln_custom_ops::
|
52
|
-
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
|
51
|
+
"rbln_custom_ops::paged_flash_attn_prefill",
|
52
|
+
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
|
53
53
|
)
|
54
54
|
|
55
55
|
@torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
|
56
|
-
def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache,
|
56
|
+
def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
|
57
57
|
return q, kcache, vcache
|
58
58
|
|
59
|
-
@register_fake("rbln_custom_ops::
|
60
|
-
def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache,
|
59
|
+
@register_fake("rbln_custom_ops::paged_flash_attn_prefill")
|
60
|
+
def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
|
61
61
|
return q, kcache, vcache
|
62
62
|
|
63
63
|
|
64
64
|
@lru_cache
|
65
|
-
def
|
65
|
+
def register_rbln_custom_paged_flash_causal_attention():
|
66
66
|
torch.library.define(
|
67
|
-
"rbln_custom_ops::
|
68
|
-
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, int
|
67
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
68
|
+
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
|
69
69
|
)
|
70
70
|
|
71
|
-
@torch.library.impl("rbln_custom_ops::
|
72
|
-
def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, partition):
|
71
|
+
@torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
|
72
|
+
def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
|
73
73
|
return (
|
74
74
|
q,
|
75
75
|
torch.empty(*kcache.shape, device=kcache.device),
|
76
76
|
torch.empty(*vcache.shape, device=vcache.device),
|
77
77
|
)
|
78
78
|
|
79
|
-
@register_fake("rbln_custom_ops::
|
80
|
-
def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, partition):
|
79
|
+
@register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
|
80
|
+
def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
|
81
81
|
return (
|
82
82
|
q,
|
83
83
|
torch.empty(*kcache.shape, device=kcache.device),
|
@@ -85,14 +85,14 @@ def register_rbln_custom_flash_causal_masked_attention():
|
|
85
85
|
)
|
86
86
|
|
87
87
|
torch.library.define(
|
88
|
-
"rbln_custom_ops::
|
89
|
-
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
|
88
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
89
|
+
"(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
|
90
90
|
)
|
91
91
|
|
92
|
-
@torch.library.impl("rbln_custom_ops::
|
93
|
-
def flash_attn_prefill_cpu(q, k, v, kcache, vcache,
|
92
|
+
@torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
|
93
|
+
def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
|
94
94
|
return q, kcache, vcache
|
95
95
|
|
96
|
-
@register_fake("rbln_custom_ops::
|
97
|
-
def flash_attn_prefill_abstract(q, k, v, kcache, vcache,
|
96
|
+
@register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
|
97
|
+
def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
|
98
98
|
return q, kcache, vcache
|
optimum/rbln/transformers/models/bart/bart_architecture.py
CHANGED
@@ -35,16 +35,16 @@ logger = logging.get_logger(__name__)
|
|
35
35
|
|
36
36
|
|
37
37
|
class BartWrapper:
|
38
|
-
def __init__(self, model: nn.Module, enc_max_seq_len: int):
|
38
|
+
def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
|
39
39
|
self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
|
40
|
-
self.decoder = BartDecoderWrapper(model)
|
40
|
+
self.decoder = BartDecoderWrapper(model, use_attention_mask=use_attention_mask)
|
41
41
|
|
42
42
|
|
43
43
|
class BartDecoderWrapper(Seq2SeqDecoderWrapper):
|
44
44
|
def convert_to_rbln_conditional_generation(self, model: nn.Module):
|
45
45
|
new_layers = []
|
46
46
|
for layer in model.get_decoder().layers:
|
47
|
-
self_attn = BartSelfAttention(layer.self_attn)
|
47
|
+
self_attn = BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
|
48
48
|
new_layers.append(BartDecoderLayer(layer, self_attn))
|
49
49
|
|
50
50
|
decoder_model = BartDecoder(model.get_decoder(), new_layers)
|
@@ -69,7 +69,8 @@ class BartDecoder(Seq2SeqDecoder):
|
|
69
69
|
self.embed_scale = getattr(self._original_mod, "embed_scale", None)
|
70
70
|
|
71
71
|
def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
|
72
|
-
attention_mask
|
72
|
+
if attention_mask is not None:
|
73
|
+
attention_mask = attention_mask[:, None, None, :]
|
73
74
|
encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
|
74
75
|
|
75
76
|
return attention_mask, encoder_attention_mask
|
@@ -134,7 +135,7 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):
|
|
134
135
|
|
135
136
|
|
136
137
|
class BartSelfAttention(Seq2SeqSelfAttention):
|
137
|
-
def __post_init__(self):
|
138
|
+
def __post_init__(self, use_attention_mask: bool = True):
|
138
139
|
self.q_proj = self._original_mod.q_proj
|
139
140
|
self.k_proj = self._original_mod.k_proj
|
140
141
|
self.v_proj = self._original_mod.v_proj
|
@@ -142,7 +143,10 @@ class BartSelfAttention(Seq2SeqSelfAttention):
|
|
142
143
|
self.num_heads = self._original_mod.num_heads
|
143
144
|
self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
|
144
145
|
self.scaling = self.head_dim**-0.5
|
145
|
-
|
146
|
+
if use_attention_mask:
|
147
|
+
self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
|
148
|
+
else:
|
149
|
+
self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode
|
146
150
|
|
147
151
|
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
148
152
|
query_states = self.q_proj(hidden_states) * self.scaling
|
optimum/rbln/transformers/models/bart/modeling_bart.py
CHANGED
@@ -113,7 +113,9 @@ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
|
113
113
|
enc_max_seq_len = (
|
114
114
|
rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
|
115
115
|
)
|
116
|
-
|
116
|
+
use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
|
117
|
+
|
118
|
+
return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)
|
117
119
|
|
118
120
|
def __getattr__(self, __name: str) -> Any:
|
119
121
|
def redirect(func):
|