optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.7.3a1'
- __version_tuple__ = version_tuple = (0, 7, 3)
+ __version__ = version = '0.7.3a3'
+ __version_tuple__ = version_tuple = (0, 7, 3, 'a3')
@@ -13,9 +13,9 @@
  # limitations under the License.

  from .attn import (
- register_rbln_custom_attention_add_softmax,
- register_rbln_custom_causal_masked_attention,
- register_rbln_custom_masked_attention,
+ register_rbln_custom_add_softmax_attention,
+ register_rbln_custom_paged_attention,
+ register_rbln_custom_paged_causal_attention,
  )
- from .flash_attn import register_rbln_custom_flash_causal_masked_attention, register_rbln_custom_flash_masked_attention
+ from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
  from .kv_cache_update import register_rbln_custom_cache_update
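
Note: the registration entry points above were renamed for the paged (block-based) KV-cache ops. As a rough orientation only, the sketch below shows how a caller might register the new ops before tracing or compiling a model; the import path optimum.rbln.ops is inferred from the package layout shown in this diff, and the register_* functions are @lru_cache-decorated, so repeated calls are harmless.

# Sketch (assumption): register the renamed paged-attention custom ops.
from optimum.rbln.ops import (
    register_rbln_custom_paged_attention,
    register_rbln_custom_paged_causal_attention,
    register_rbln_custom_paged_flash_attention,
    register_rbln_custom_paged_flash_causal_attention,
)

register_rbln_custom_paged_attention()         # mask-based paged attention
register_rbln_custom_paged_causal_attention()  # causal paged attention (no explicit mask)
register_rbln_custom_paged_flash_attention()   # flash variants additionally take a partition size
register_rbln_custom_paged_flash_causal_attention()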
optimum/rbln/ops/attn.py CHANGED
@@ -25,14 +25,14 @@ else:


  @lru_cache
- def register_rbln_custom_masked_attention():
+ def register_rbln_custom_paged_attention():
  torch.library.define(
- "rbln_custom_ops::masked_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+ "rbln_custom_ops::paged_attn_decode",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::masked_attn_decode", "cpu")
- def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+ @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
+ def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
  """Defines the computation pattern for fused attention with KV cache updates.

  IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -51,8 +51,10 @@ def register_rbln_custom_masked_attention():
  - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
  - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
  - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - seq: [1] - Current sequence position
+ - seq: [1, 1] - Current sequence position
  - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block

  Returns:
  Tuple[Tensor, Tensor, Tensor]:
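
Note: the decode op now takes a block_table tensor plus an integer block_size, and seq is shaped [1, 1]. The following is a minimal sketch (not part of this diff) of calling the CPU pattern stub with illustrative shapes; on CPU it only returns placeholder outputs, since the op exists as a pattern for the RBLN compiler rather than a real kernel.

# Minimal sketch with assumed, illustrative dimensions; the import path is inferred
# from this diff and the CPU impl merely returns placeholders.
import torch
from optimum.rbln.ops import register_rbln_custom_paged_attention

batch_size, n_heads, head_dim = 2, 8, 64
max_seq_len, block_size = 128, 32

q = torch.randn(1, n_heads, 1, 1, head_dim)   # single decode-step query
k = torch.randn(1, n_heads, 1, 1, head_dim)
v = torch.randn(1, n_heads, 1, 1, head_dim)
mask = torch.zeros(1, n_heads, 1, 1, max_seq_len)
kcache = torch.zeros(batch_size, n_heads, 1, max_seq_len, head_dim)
vcache = torch.zeros(batch_size, n_heads, 1, max_seq_len, head_dim)
seq = torch.zeros(1, 1, dtype=torch.int32)    # current position, now shaped [1, 1]
scale = torch.tensor(head_dim ** -0.5)
block_table = torch.arange(max_seq_len // block_size, dtype=torch.int32).repeat(batch_size, 1)

register_rbln_custom_paged_attention()
attn_out, kcache_stub, vcache_stub = torch.ops.rbln_custom_ops.paged_attn_decode(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size
)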
@@ -66,8 +68,8 @@ def register_rbln_custom_masked_attention():
  torch.empty(*vcache.shape, device=vcache.device),
  )

- @register_fake("rbln_custom_ops::masked_attn_decode")
- def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+ @register_fake("rbln_custom_ops::paged_attn_decode")
+ def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
@@ -75,12 +77,12 @@ def register_rbln_custom_masked_attention():
  )

  torch.library.define(
- "rbln_custom_ops::masked_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+ "rbln_custom_ops::paged_attn_prefill",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::masked_attn_prefill", "cpu")
- def attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
+ @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
+ def attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
  """Defines the computation pattern for prefill phase attention with KV cache updates.

  IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -97,9 +99,10 @@ def register_rbln_custom_masked_attention():
  - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
  - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
  - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - batch: [1] - Batch index for cache access
- - seq: [1] - Starting sequence position
+ - seq: [1, 1] - Starting sequence position
  - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block

  Returns:
  Tuple[Tensor, Tensor, Tensor]:
@@ -109,20 +112,20 @@ def register_rbln_custom_masked_attention():
  """
  return q, kcache, vcache

- @register_fake("rbln_custom_ops::masked_attn_prefill")
- def attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
+ @register_fake("rbln_custom_ops::paged_attn_prefill")
+ def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
  return q, kcache, vcache


  @lru_cache
- def register_rbln_custom_causal_masked_attention():
+ def register_rbln_custom_paged_causal_attention():
  torch.library.define(
- "rbln_custom_ops::causal_masked_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+ "rbln_custom_ops::paged_causal_attn_decode",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::causal_masked_attn_decode", "cpu")
- def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale):
+ @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
+ def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
  """Defines the computation pattern for fused attention with KV cache updates.

  IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -140,8 +143,10 @@ def register_rbln_custom_causal_masked_attention():
  - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
  - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
  - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - seq: [1] - Current sequence position
+ - seq: [1, 1] - Starting sequence position
  - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block

  Returns:
  Tuple[Tensor, Tensor, Tensor]:
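
Note: the causal variant documented above mirrors the paged decode op but drops the explicit mask tensor; causality is derived from the sequence position. A short illustrative sketch of the resulting call, reusing the tensors from the earlier paged_attn_decode example:

# Sketch (assumption): same tensors as the earlier paged_attn_decode example, minus
# the mask; register_rbln_custom_paged_causal_attention() must have been called first.
attn_out, kcache_stub, vcache_stub = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
    q, k, v, kcache, vcache, seq, scale, block_table, block_size
)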
@@ -155,8 +160,8 @@ def register_rbln_custom_causal_masked_attention():
  torch.empty(*vcache.shape, device=vcache.device),
  )

- @register_fake("rbln_custom_ops::causal_masked_attn_decode")
- def attn_decode_abstract(q, k, v, kcache, vcache, seq, partition):
+ @register_fake("rbln_custom_ops::paged_causal_attn_decode")
+ def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
@@ -164,12 +169,12 @@ def register_rbln_custom_causal_masked_attention():
  )

  torch.library.define(
- "rbln_custom_ops::causal_masked_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+ "rbln_custom_ops::paged_causal_attn_prefill",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::causal_masked_attn_prefill", "cpu")
- def attn_prefill_cpu(q, k, v, kcache, vcache, batch, seq, scale):
+ @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
+ def attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
  """Defines the computation pattern for prefill phase attention with KV cache updates.

  IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -186,8 +191,10 @@ def register_rbln_custom_causal_masked_attention():
  - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
  - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
  - batch: [1] - Batch index for cache access
- - seq: [1] - Starting sequence position
+ - seq: [1, 1] - Starting sequence position
  - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block

  Returns:
  Tuple[Tensor, Tensor, Tensor]:
@@ -197,20 +204,20 @@ def register_rbln_custom_causal_masked_attention():
  """
  return q, kcache, vcache

- @register_fake("rbln_custom_ops::causal_masked_attn_prefill")
- def attn_prefill_abstract(q, k, v, kcache, vcache, batch, seq, partition):
+ @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
+ def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
  return q, kcache, vcache


  @lru_cache
- def register_rbln_custom_attention_add_softmax():
+ def register_rbln_custom_add_softmax_attention():
  torch.library.define(
- "rbln_custom_ops::attn_decode_add_softmax",
+ "rbln_custom_ops::add_softmax_attn_decode",
  "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::attn_decode_add_softmax", "cpu")
- def attn_decode_add_softmax_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+ @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
+ def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
  """Defines the computation pattern for fused attention with KV cache updates.

  IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -244,57 +251,10 @@ def register_rbln_custom_attention_add_softmax():
  torch.empty(*vcache.shape, device=vcache.device),
  )

- @register_fake("rbln_custom_ops::attn_decode_add_softmax")
- def attn_decode_add_softmax_abstract(q, k, v, m, kcache, vcache, seq, partition):
+ @register_fake("rbln_custom_ops::add_softmax_attn_decode")
+ def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
  torch.empty(*vcache.shape, device=vcache.device),
  )
-
- torch.library.define(
- "rbln_custom_ops::attn_prefill_add_softmax",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
- )
-
- @torch.library.impl("rbln_custom_ops::attn_prefill_add_softmax", "cpu")
- def attn_prefill_add_softmax_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
- """Defines the computation pattern for prefill phase attention with KV cache updates.
-
- IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
- a single optimized NPU operation. It is NOT meant for CPU execution.
-
- Key differences from decode pattern:
- - Handles prefill phase with multiple input tokens
- - Takes explicit batch index for continuous batching
-
- Expected tensor shapes:
- - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
- - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
- - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
- - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
- - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
- - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - batch: [1] - Batch index for cache access
- - seq: [1] - Starting sequence position
- - scale: [] - Attention scale factor
-
- Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, seq_len, 1, head_dim] - Attention output
- - empty_kcache: Same shape as input kcache - Placeholder for compiler
- - empty_vcache: Same shape as input vcache - Placeholder for compiler
- """
- return (
- q,
- torch.empty(1, *kcache.shape[1:], device=kcache.device),
- torch.empty(1, *vcache.shape[1:], device=vcache.device),
- )
-
- @register_fake("rbln_custom_ops::attn_prefill_add_softmax")
- def attn_prefill_add_softmax_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
- return (
- q,
- torch.empty(1, *kcache.shape[1:], device=kcache.device),
- torch.empty(1, *vcache.shape[1:], device=vcache.device),
- )
@@ -25,22 +25,22 @@ else:


  @lru_cache
- def register_rbln_custom_flash_masked_attention():
+ def register_rbln_custom_paged_flash_attention():
  torch.library.define(
- "rbln_custom_ops::flash_masked_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
+ "rbln_custom_ops::paged_flash_attn_decode",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::flash_masked_attn_decode", "cpu")
- def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
+ @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
+ def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
  torch.empty(*vcache.shape, device=vcache.device),
  )

- @register_fake("rbln_custom_ops::flash_masked_attn_decode")
- def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
+ @register_fake("rbln_custom_ops::paged_flash_attn_decode")
+ def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
@@ -48,36 +48,36 @@ def register_rbln_custom_flash_masked_attention():
  )

  torch.library.define(
- "rbln_custom_ops::flash_masked_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "rbln_custom_ops::paged_flash_attn_prefill",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
  )

  @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
- def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
+ def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
  return q, kcache, vcache

- @register_fake("rbln_custom_ops::flash_masked_attn_prefill")
- def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, scale, partition):
+ @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
+ def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
  return q, kcache, vcache


  @lru_cache
- def register_rbln_custom_flash_causal_masked_attention():
+ def register_rbln_custom_paged_flash_causal_attention():
  torch.library.define(
- "rbln_custom_ops::flash_causal_masked_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
+ "rbln_custom_ops::paged_flash_causal_attn_decode",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::flash_causal_masked_attn_decode", "cpu")
- def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, partition):
+ @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
+ def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
  torch.empty(*vcache.shape, device=vcache.device),
  )

- @register_fake("rbln_custom_ops::flash_causal_masked_attn_decode")
- def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, partition):
+ @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
+ def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
  return (
  q,
  torch.empty(*kcache.shape, device=kcache.device),
@@ -85,14 +85,14 @@ def register_rbln_custom_flash_causal_masked_attention():
  )

  torch.library.define(
- "rbln_custom_ops::flash_causal_masked_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "rbln_custom_ops::paged_flash_causal_attn_prefill",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
  )

- @torch.library.impl("rbln_custom_ops::flash_causal_masked_attn_prefill", "cpu")
- def flash_attn_prefill_cpu(q, k, v, kcache, vcache, batch, seq, scale, partition):
+ @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
+ def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
  return q, kcache, vcache

- @register_fake("rbln_custom_ops::flash_causal_masked_attn_prefill")
- def flash_attn_prefill_abstract(q, k, v, kcache, vcache, batch, seq, scale, partition):
+ @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
+ def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
  return q, kcache, vcache
@@ -35,16 +35,16 @@ logger = logging.get_logger(__name__)


  class BartWrapper:
- def __init__(self, model: nn.Module, enc_max_seq_len: int):
+ def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
  self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
- self.decoder = BartDecoderWrapper(model)
+ self.decoder = BartDecoderWrapper(model, use_attention_mask=use_attention_mask)


  class BartDecoderWrapper(Seq2SeqDecoderWrapper):
  def convert_to_rbln_conditional_generation(self, model: nn.Module):
  new_layers = []
  for layer in model.get_decoder().layers:
- self_attn = BartSelfAttention(layer.self_attn)
+ self_attn = BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
  new_layers.append(BartDecoderLayer(layer, self_attn))

  decoder_model = BartDecoder(model.get_decoder(), new_layers)
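
Note: BartWrapper now threads a use_attention_mask flag down to the decoder self-attention, which selects between the masked and causal paged decode ops (see the BartSelfAttention hunk below). A hedged construction sketch follows; the checkpoint and the availability of BartWrapper as an importable name are illustrative assumptions, and only the wrapper signature comes from this diff.

# Sketch (assumptions: checkpoint choice; BartWrapper imported from the module shown in this diff).
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# use_attention_mask=True keeps the mask-based paged_attn_decode path;
# False switches decoder self-attention to paged_causal_attn_decode.
wrapper = BartWrapper(model, enc_max_seq_len=1024, use_attention_mask=False)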
@@ -69,7 +69,8 @@ class BartDecoder(Seq2SeqDecoder):
  self.embed_scale = getattr(self._original_mod, "embed_scale", None)

  def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
- attention_mask = attention_mask[:, None, None, :]
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, None, None, :]
  encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)

  return attention_mask, encoder_attention_mask
@@ -134,7 +135,7 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):


  class BartSelfAttention(Seq2SeqSelfAttention):
- def __post_init__(self):
+ def __post_init__(self, use_attention_mask: bool = True):
  self.q_proj = self._original_mod.q_proj
  self.k_proj = self._original_mod.k_proj
  self.v_proj = self._original_mod.v_proj
@@ -142,7 +143,10 @@ class BartSelfAttention(Seq2SeqSelfAttention):
  self.num_heads = self._original_mod.num_heads
  self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
  self.scaling = self.head_dim**-0.5
- self.attn_decode = torch.ops.rbln_custom_ops.masked_attn_decode
+ if use_attention_mask:
+ self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+ else:
+ self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode

  def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  query_states = self.q_proj(hidden_states) * self.scaling
@@ -113,7 +113,9 @@ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
  enc_max_seq_len = (
  rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
  )
- return BartWrapper(model, enc_max_seq_len=enc_max_seq_len)
+ use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+
+ return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)

  def __getattr__(self, __name: str) -> Any:
  def redirect(func):