optimum-rbln 0.7.2rc1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +9 -4
  3. optimum/rbln/diffusers/__init__.py +8 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +103 -109
  5. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
  6. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
  7. optimum/rbln/diffusers/pipelines/__init__.py +8 -0
  8. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
  9. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
  10. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
  11. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
  12. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
  13. optimum/rbln/modeling.py +4 -1
  14. optimum/rbln/modeling_base.py +16 -3
  15. optimum/rbln/ops/__init__.py +6 -2
  16. optimum/rbln/ops/attn.py +94 -85
  17. optimum/rbln/ops/flash_attn.py +46 -25
  18. optimum/rbln/ops/kv_cache_update.py +4 -4
  19. optimum/rbln/transformers/modeling_generic.py +3 -3
  20. optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
  21. optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
  22. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
  25. optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
  26. optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
  27. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
  28. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
  29. optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
  30. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
  31. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
  32. optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
  33. optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
  34. optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
  35. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
  36. optimum/rbln/utils/import_utils.py +7 -0
  37. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
  38. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
  39. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
  40. {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling.py CHANGED
@@ -134,6 +134,9 @@ class RBLNModel(RBLNBaseModel):
         for preprocessor in preprocessors:
             preprocessor.save_pretrained(save_dir_path / subfolder)
 
+        # ad-hoc
+        rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+
         # Get compilation arguments (e.g. input_info)
         rbln_config: RBLNConfig = cls.get_rbln_config(
             preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
@@ -196,7 +199,7 @@ class RBLNModel(RBLNBaseModel):
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
-        return cls.hf_class.from_pretrained(
+        return cls.get_hf_class().from_pretrained(
             model_id,
             subfolder=subfolder,
             revision=revision,
optimum/rbln/modeling_base.py CHANGED
@@ -295,6 +295,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     ):
         if isinstance(model_save_dir, str):
             model_save_dir = Path(model_save_dir)
+
         # FIXME:: Should we convert it?
         compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
         rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
@@ -389,8 +390,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return rbln_config
 
     @classmethod
-    @property
-    def hf_class(cls):
+    def get_hf_class(cls):
         """
         Lazily loads and caches the corresponding Hugging Face model class.
         Removes 'RBLN' prefix from the class name to get the original class name
@@ -416,7 +416,20 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return self.forward(*args, **kwargs)
 
     def __repr__(self):
-        return repr(self.model) + repr(self.rbln_submodules)
+        has_submodules = len(self.rbln_submodules) > 0
+        repr_str: str = f"<{self.__class__.__name__}>\n"
+        repr_str += f"- Total {len(self.model)} Runtimes"
+        repr_str += f" and {len(self.rbln_submodules)} Submodules\n" if has_submodules else "\n"
+        repr_str += "[Runtimes]\n"
+        repr_str += "\n".join([repr(model) for model in self.model])
+        repr_str += "\n"
+
+        if has_submodules > 0:
+            for i, submodule in enumerate(self.rbln_submodules):
+                repr_str += f"[Submodules {i} : {self._rbln_submodules[i]['name']}]\n"
+                repr_str += repr(submodule) + "\n"
+
+        return repr_str
 
     def __post_init__(self, **kwargs):
         pass
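Note that `hf_class` becomes an explicit `get_hf_class()` classmethod in 0.7.3, and every internal call site in this diff (modeling.py above, plus the modeling_generic.py, modeling_bart.py, and modeling_bert.py hunks below) is updated to match. A minimal migration sketch for downstream code, assuming only that `RBLNBertModel` is importable from the installed package:

# Hedged sketch: adapting a downstream call site to the 0.7.3 rename.
import inspect

from optimum.rbln import RBLNBertModel

# optimum-rbln <= 0.7.2rc1 exposed the original Hugging Face class as a class-level property:
#     hf_cls = RBLNBertModel.hf_class
# optimum-rbln >= 0.7.3 uses an explicit classmethod instead:
hf_cls = RBLNBertModel.get_hf_class()

# Same downstream use as in modeling_generic.py: inspect the original forward signature.
signature_params = inspect.signature(hf_cls.forward).parameters.keys()
print(list(signature_params))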
optimum/rbln/ops/__init__.py CHANGED
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .attn import register_rbln_custom_attention, register_rbln_custom_attention_add_softmax
-from .flash_attn import register_rbln_custom_flash_attention
+from .attn import (
+    register_rbln_custom_add_softmax_attention,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+)
+from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
 from .kv_cache_update import register_rbln_custom_cache_update
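The old `register_rbln_custom_attention`, `register_rbln_custom_attention_add_softmax`, and `register_rbln_custom_flash_attention` entry points are gone, so any code that registered the custom ops by hand must switch to the paged variants. A minimal sketch of registering the 0.7.3 ops using only the names re-exported above; each helper is wrapped in `lru_cache`, so repeated calls are harmless:

# Hedged sketch: registering the renamed custom ops exposed by optimum.rbln.ops in 0.7.3.
from optimum.rbln.ops import (
    register_rbln_custom_add_softmax_attention,
    register_rbln_custom_cache_update,
    register_rbln_custom_paged_attention,
    register_rbln_custom_paged_causal_attention,
    register_rbln_custom_paged_flash_causal_attention,
)

# Each helper is decorated with functools.lru_cache, so calling it twice is a no-op.
register_rbln_custom_paged_attention()
register_rbln_custom_paged_causal_attention()
register_rbln_custom_paged_flash_causal_attention()
register_rbln_custom_add_softmax_attention()
register_rbln_custom_cache_update()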
optimum/rbln/ops/attn.py CHANGED
@@ -25,14 +25,14 @@ else:
 
 
 @lru_cache
-def register_rbln_custom_attention():
+def register_rbln_custom_paged_attention():
     torch.library.define(
-        "rbln_custom_ops::attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+        "rbln_custom_ops::paged_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
-    @torch.library.impl("rbln_custom_ops::attn_decode", "cpu")
-    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+    @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for fused attention with KV cache updates.
 
         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -51,36 +51,27 @@ def register_rbln_custom_attention():
         - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
-        - seq: [1] - Current sequence position
+        - seq: [1, 1] - Current sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block
 
         Returns:
-            Tuple[Tensor, Tensor, Tensor]:
-                - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
-                - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-                - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
-
-    @register_fake("rbln_custom_ops::attn_decode")
-    def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
+
+    @register_fake("rbln_custom_ops::paged_attn_decode")
+    def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
+        return q
 
     torch.library.define(
-        "rbln_custom_ops::attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+        "rbln_custom_ops::paged_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
-    @torch.library.impl("rbln_custom_ops::attn_prefill", "cpu")
-    def attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
+    @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for prefill phase attention with KV cache updates.
 
         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -97,32 +88,30 @@ def register_rbln_custom_attention():
         - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
-        - batch: [1] - Batch index for cache access
-        - seq: [1] - Starting sequence position
+        - seq: [1, 1] - Starting sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block
 
         Returns:
-            Tuple[Tensor, Tensor, Tensor]:
-                - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
-                - empty_kcache: Same shape as input kcache - Placeholder for compiler
-                - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return q, kcache, vcache
+        return q
 
-    @register_fake("rbln_custom_ops::attn_prefill")
-    def attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
-        return q, kcache, vcache
+    @register_fake("rbln_custom_ops::paged_attn_prefill")
+    def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
+        return q
 
 
 @lru_cache
-def register_rbln_custom_attention_add_softmax():
+def register_rbln_custom_paged_causal_attention():
     torch.library.define(
-        "rbln_custom_ops::attn_decode_add_softmax",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+        "rbln_custom_ops::paged_causal_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
-    @torch.library.impl("rbln_custom_ops::attn_decode_add_softmax", "cpu")
-    def attn_decode_add_softmax_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+    @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for fused attention with KV cache updates.
 
         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -131,46 +120,36 @@ def register_rbln_custom_attention_add_softmax():
         Pattern components that compiler fuses into a single op:
         1. KV cache updates with new key/value states
         2. Scaled dot-product attention computation
-        3. add-softmax operation
+        3. Causal masked softmax operation
         4. Final attention output computation
 
         Expected tensor shapes:
         - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
         - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
         - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
-        - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
-        - seq: [1] - Current sequence position
+        - seq: [1, 1] - Starting sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block
 
         Returns:
-            Tuple[Tensor, Tensor, Tensor]:
-                - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
-                - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-                - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
-
-    @register_fake("rbln_custom_ops::attn_decode_add_softmax")
-    def attn_decode_add_softmax_abstract(q, k, v, m, kcache, vcache, seq, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
+
+    @register_fake("rbln_custom_ops::paged_causal_attn_decode")
+    def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
+        return q
 
     torch.library.define(
-        "rbln_custom_ops::attn_prefill_add_softmax",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+        "rbln_custom_ops::paged_causal_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
    )
 
-    @torch.library.impl("rbln_custom_ops::attn_prefill_add_softmax", "cpu")
-    def attn_prefill_add_softmax_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
+    @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for prefill phase attention with KV cache updates.
 
         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -184,29 +163,59 @@ def register_rbln_custom_attention_add_softmax():
         - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
         - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
         - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
-        - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
         - batch: [1] - Batch index for cache access
-        - seq: [1] - Starting sequence position
+        - seq: [1, 1] - Starting sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block
 
         Returns:
-            Tuple[Tensor, Tensor, Tensor]:
-                - attn_output: [batch=1, n_heads, seq_len, 1, head_dim] - Attention output
-                - empty_kcache: Same shape as input kcache - Placeholder for compiler
-                - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(1, *kcache.shape[1:], device=kcache.device),
-            torch.empty(1, *vcache.shape[1:], device=vcache.device),
-        )
-
-    @register_fake("rbln_custom_ops::attn_prefill_add_softmax")
-    def attn_prefill_add_softmax_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
-        return (
-            q,
-            torch.empty(1, *kcache.shape[1:], device=kcache.device),
-            torch.empty(1, *vcache.shape[1:], device=vcache.device),
-        )
+        return q
+
+    @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
+    def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
+        return q
+
+
+@lru_cache
+def register_rbln_custom_add_softmax_attention():
+    torch.library.define(
+        "rbln_custom_ops::add_softmax_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
+    )
+
+    @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
+    def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+        """Defines the computation pattern for fused attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Pattern components that compiler fuses into a single op:
+        1. KV cache updates with new key/value states
+        2. Scaled dot-product attention computation
+        3. add-softmax operation
+        4. Final attention output computation
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+        - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+        - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - seq: [1] - Current sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
+        """
+        return q
+
+    @register_fake("rbln_custom_ops::add_softmax_attn_decode")
+    def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+        return q
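The docstrings above spell out the new calling convention: each paged op now takes a [batch_size, max_seq_len // block_size] block table plus an integer block size, and returns only the attention output instead of the old (output, kcache, vcache) tuple. A hedged sketch that exercises the CPU pattern stub of paged_attn_decode with the documented shapes; the dimension values are illustrative only, and on CPU the stub simply echoes q, since the fused kernel exists only after RBLN compilation:

# Hedged sketch: calling the CPU pattern stub of the new paged decode op.
import torch

from optimum.rbln.ops import register_rbln_custom_paged_attention

register_rbln_custom_paged_attention()

batch, n_heads, n_groups, head_dim = 1, 8, 4, 64
max_seq_len, block_size = 1024, 128

q = torch.randn(batch, n_heads, n_groups, 1, head_dim)
k = torch.randn(batch, n_heads, 1, 1, head_dim)
v = torch.randn(batch, n_heads, 1, 1, head_dim)
mask = torch.zeros(batch, n_heads, 1, 1, max_seq_len)
kcache = torch.zeros(batch, n_heads, 1, max_seq_len, head_dim)
vcache = torch.zeros(batch, n_heads, 1, max_seq_len, head_dim)
seq = torch.zeros(1, 1, dtype=torch.int32)
scale = torch.tensor(head_dim**-0.5)
block_table = torch.arange(max_seq_len // block_size, dtype=torch.int16).unsqueeze(0)

out = torch.ops.rbln_custom_ops.paged_attn_decode(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size
)
assert out.shape == q.shape  # the CPU stub returns q unchanged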
optimum/rbln/ops/flash_attn.py CHANGED
@@ -25,37 +25,58 @@ else:
 
 
 @lru_cache
-def register_rbln_custom_flash_attention():
+def register_rbln_custom_paged_flash_attention():
     torch.library.define(
-        "rbln_custom_ops::flash_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
+        "rbln_custom_ops::paged_flash_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
    )
 
-    @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
-    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
-
-    @register_fake("rbln_custom_ops::flash_attn_decode")
-    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+    @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    @register_fake("rbln_custom_ops::paged_flash_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
 
     torch.library.define(
-        "rbln_custom_ops::flash_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+        "rbln_custom_ops::paged_flash_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
    )
 
     @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
-    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
-        return q, kcache, vcache
+    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+
+@lru_cache
+def register_rbln_custom_paged_flash_causal_attention():
+    torch.library.define(
+        "rbln_custom_ops::paged_flash_causal_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
+    )
+
+    @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    torch.library.define(
+        "rbln_custom_ops::paged_flash_causal_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
+    )
+
+    @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
+    def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
 
-    @register_fake("rbln_custom_ops::flash_attn_prefill")
-    def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, scale, partition):
-        return q, kcache, vcache
+    @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
optimum/rbln/ops/kv_cache_update.py CHANGED
@@ -45,10 +45,10 @@ def register_rbln_custom_cache_update():
 
         # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
        # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-        updated_cache = cache.slice_scatter(state, dim=axis, start=s, end=e)
+        cache.slice_scatter(state, dim=axis, start=s, end=e)
 
-        # Return the updated cache tensor.
-        return updated_cache
+        # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
+        return torch.empty([256])
 
     # Register a "fake" implementation of the "rbln_cache_update" operation.
     # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
@@ -57,4 +57,4 @@ def register_rbln_custom_cache_update():
         # Return a tensor with the same shape as the input cache tensor.
         # This is a placeholder for the abstract implementation and does not perform any actual computation.
         # Like the actual implementation, the abstraction assumes in-place device-side updates.
-        return torch.empty_like(cache)
+        return torch.empty([256])
optimum/rbln/transformers/modeling_generic.py CHANGED
@@ -73,7 +73,7 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -289,7 +289,7 @@ class RBLNModelForSequenceClassification(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -362,7 +362,7 @@ class RBLNModelForMaskedLM(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
optimum/rbln/transformers/models/bart/bart_architecture.py CHANGED
@@ -35,16 +35,16 @@ logger = logging.get_logger(__name__)
 
 
 class BartWrapper:
-    def __init__(self, model: nn.Module, enc_max_seq_len: int):
+    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
-        self.decoder = BartDecoderWrapper(model)
+        self.decoder = BartDecoderWrapper(model, use_attention_mask=use_attention_mask)
 
 
 class BartDecoderWrapper(Seq2SeqDecoderWrapper):
     def convert_to_rbln_conditional_generation(self, model: nn.Module):
         new_layers = []
         for layer in model.get_decoder().layers:
-            self_attn = BartSelfAttention(layer.self_attn)
+            self_attn = BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
             new_layers.append(BartDecoderLayer(layer, self_attn))
 
         decoder_model = BartDecoder(model.get_decoder(), new_layers)
@@ -69,7 +69,8 @@ class BartDecoder(Seq2SeqDecoder):
         self.embed_scale = getattr(self._original_mod, "embed_scale", None)
 
     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
-        attention_mask = attention_mask[:, None, None, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :]
         encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
 
         return attention_mask, encoder_attention_mask
@@ -134,7 +135,7 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):
 
 
 class BartSelfAttention(Seq2SeqSelfAttention):
-    def __post_init__(self):
+    def __post_init__(self, use_attention_mask: bool = True):
         self.q_proj = self._original_mod.q_proj
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
@@ -142,7 +143,10 @@ class BartSelfAttention(Seq2SeqSelfAttention):
         self.num_heads = self._original_mod.num_heads
         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
         self.scaling = self.head_dim**-0.5
-        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+        if use_attention_mask:
+            self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+        else:
+            self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states) * self.scaling
optimum/rbln/transformers/models/bart/modeling_bart.py CHANGED
@@ -58,7 +58,7 @@ class RBLNBartModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -108,12 +108,16 @@ class RBLNBartModel(RBLNModel):
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_paged_causal_attn = True
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = (
             rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
         )
-        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len)
+        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+
+        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
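Taken together, these BART changes make the paged causal attention path (no explicit attention mask) the default for RBLNBartForConditionalGeneration. A minimal compile-and-save sketch under that default; `facebook/bart-base` is just an example checkpoint, and the commented-out `rbln_use_attention_mask` kwarg is an assumption about how `use_attention_mask` would be surfaced through the usual `rbln_` kwarg convention, not something this diff confirms:

# Hedged sketch: compiling BART for RBLN NPUs with the 0.7.3 defaults.
from optimum.rbln import RBLNBartForConditionalGeneration

model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",   # example checkpoint
    export=True,            # compile from the original Hugging Face weights
    rbln_batch_size=1,
    # rbln_use_attention_mask=True,  # assumed opt-in back to the masked (non-causal) attention op
)
model.save_pretrained("bart-base-rbln")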
optimum/rbln/transformers/models/bert/modeling_bert.py CHANGED
@@ -56,7 +56,7 @@ class RBLNBertModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors: