optimum-rbln 0.7.4a2__py3-none-any.whl → 0.7.4a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. optimum/rbln/__version__.py +1 -1
  2. optimum/rbln/modeling.py +8 -1
  3. optimum/rbln/modeling_base.py +0 -5
  4. optimum/rbln/ops/__init__.py +3 -7
  5. optimum/rbln/ops/attn.py +271 -207
  6. optimum/rbln/ops/flash_attn.py +161 -67
  7. optimum/rbln/ops/kv_cache_update.py +4 -40
  8. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
  10. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
  11. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
  12. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  13. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  14. optimum/rbln/transformers/models/t5/modeling_t5.py +3 -37
  15. optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
  16. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  17. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
  18. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  19. optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
  20. optimum/rbln/transformers/models/whisper/whisper_architecture.py +20 -32
  21. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/METADATA +1 -1
  22. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/RECORD +24 -24
  23. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/WHEEL +0 -0
  24. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.7.4a2'
+ __version__ = version = '0.7.4a4'
  __version_tuple__ = version_tuple = (0, 7, 4)
optimum/rbln/modeling.py CHANGED
@@ -123,8 +123,15 @@ class RBLNModel(RBLNBaseModel):
  config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
 
  if hasattr(model, "can_generate") and model.can_generate():
+ import json
+
  generation_config = model.generation_config
- generation_config.save_pretrained(save_dir_path / subfolder)
+ generation_config_path = save_dir_path / subfolder / "generation_config.json"
+
+ generation_config.save_pretrained(generation_config_path.parent)
+ local_config = json.loads(generation_config_path.read_text(encoding="utf-8"))
+ local_config["transformers_version"] = generation_config.transformers_version
+ generation_config_path.write_text(json.dumps(local_config, indent=2) + "\n", encoding="utf-8")
 
  if not isinstance(config, PretrainedConfig): # diffusers config
  config = PretrainedConfig(**config)
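
The modeling.py hunk above replaces the plain generation_config.save_pretrained(...) call with a save-then-rewrite step, so the transformers_version field written into generation_config.json keeps the version recorded on the GenerationConfig object instead of being overwritten with whatever transformers version happens to be installed at export time. A minimal standalone sketch of the same pattern (the helper name and out_dir argument are illustrative, not part of the package):

    import json
    from pathlib import Path
    from transformers import GenerationConfig

    def save_generation_config_pinned(gen_config: GenerationConfig, out_dir: Path) -> None:
        # save_pretrained() writes generation_config.json but stamps the currently
        # installed transformers version into it ...
        out_dir.mkdir(parents=True, exist_ok=True)
        gen_config.save_pretrained(out_dir)
        path = out_dir / "generation_config.json"
        data = json.loads(path.read_text(encoding="utf-8"))
        # ... so restore the version recorded on the config object, as the diff does.
        data["transformers_version"] = gen_config.transformers_version
        path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
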
optimum/rbln/modeling_base.py CHANGED
@@ -481,11 +481,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  # First copy everything to a temporary directory
  shutil.copytree(real_save_dir, tmp_dir)
 
- # Save configs to the temporary directory
- self.config.save_pretrained(tmp_dir)
- if self.generation_config is not None:
- self.generation_config.save_pretrained(tmp_dir)
-
  # If everything succeeded, atomically replace the target directory
  if os.path.exists(save_directory_path):
  shutil.rmtree(save_directory_path)
optimum/rbln/ops/__init__.py CHANGED
@@ -12,11 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from .attn import (
- register_rbln_custom_paged_add_softmax_attention,
- register_rbln_custom_paged_attention,
- register_rbln_custom_paged_causal_attention,
- )
- from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
- from .kv_cache_update import register_rbln_custom_cache_update
+ from .attn import *
+ from .flash_attn import *
+ from .kv_cache_update import *
  from .linear import linear
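
With the wildcard imports, the custom ops defined in attn.py, flash_attn.py, and kv_cache_update.py are registered as a side effect of importing optimum.rbln.ops; there is no longer a register_rbln_custom_*() call to make first. A rough usage sketch, assuming PyTorch >= 2.4 and that the import succeeds in the local environment (shapes follow the docstrings in attn.py and are illustrative only):

    import torch
    import optimum.rbln.ops  # noqa: F401 -- importing registers the rbln_custom_ops library

    n_heads, n_groups, head_dim, max_seq, block = 8, 4, 64, 128, 16
    q = torch.randn(1, n_heads, n_groups, 1, head_dim)
    k = torch.randn(1, n_heads, 1, 1, head_dim)
    v = torch.randn(1, n_heads, 1, 1, head_dim)
    mask = torch.zeros(1, n_heads, 1, 1, max_seq)
    kcache = torch.zeros(1, n_heads, 1, max_seq, head_dim)
    vcache = torch.zeros(1, n_heads, 1, max_seq, head_dim)
    seq = torch.zeros(1, 1, dtype=torch.int32)
    scale = torch.tensor(head_dim ** -0.5)
    block_table = torch.arange(max_seq // block, dtype=torch.int16).unsqueeze(0)

    # On CPU this only runs the placeholder body (an empty tensor shaped like q);
    # the real fused kernel is generated by the RBLN compiler for the NPU.
    out = torch.ops.rbln_custom_ops.paged_attn_decode(
        q, k, v, mask, kcache, vcache, seq, scale, block_table, block
    )
    print(out.shape)  # torch.Size([1, 8, 4, 1, 64])
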
optimum/rbln/ops/attn.py CHANGED
@@ -12,212 +12,276 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from functools import lru_cache
 
  import torch
- from packaging import version
-
-
- if version.parse(torch.__version__) > version.parse("2.4.0"):
- register_fake = torch.library.register_fake
- else:
- register_fake = torch.library.impl_abstract
-
-
- @lru_cache
- def register_rbln_custom_paged_attention():
- torch.library.define(
- "rbln_custom_ops::paged_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
- )
-
- @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
- def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
- """Defines the computation pattern for fused attention with KV cache updates.
-
- IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
- a single optimized NPU operation. It is NOT meant for CPU execution.
-
- Pattern components that compiler fuses into a single op:
- 1. KV cache updates with new key/value states
- 2. Scaled dot-product attention computation
- 3. Masked softmax operation
- 4. Final attention output computation
-
- Expected tensor shapes:
- - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
- - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
- - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
- - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
- - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
- - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - seq: [1, 1] - Current sequence position
- - scale: [] - Attention scale factor
- - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
- - block_size: [] - Number of tokens per block
-
- Returns:
- Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
- """
- return q
-
- @register_fake("rbln_custom_ops::paged_attn_decode")
- def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
- return q
-
- torch.library.define(
- "rbln_custom_ops::paged_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
- )
-
- @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
- def attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
- """Defines the computation pattern for prefill phase attention with KV cache updates.
-
- IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
- a single optimized NPU operation. It is NOT meant for CPU execution.
-
- Key differences from decode pattern:
- - Handles prefill phase with multiple input tokens
- - Takes explicit batch index for continuous batching
-
- Expected tensor shapes:
- - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
- - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
- - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
- - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
- - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
- - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - seq: [1, 1] - Starting sequence position
- - scale: [] - Attention scale factor
- - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
- - block_size: [] - Number of tokens per block
-
- Returns:
- Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
- """
- return q
-
- @register_fake("rbln_custom_ops::paged_attn_prefill")
- def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
- return q
-
-
- @lru_cache
- def register_rbln_custom_paged_causal_attention():
- torch.library.define(
- "rbln_custom_ops::paged_causal_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
- )
-
- @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
- def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- """Defines the computation pattern for fused attention with KV cache updates.
-
- IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
- a single optimized NPU operation. It is NOT meant for CPU execution.
-
- Pattern components that compiler fuses into a single op:
- 1. KV cache updates with new key/value states
- 2. Scaled dot-product attention computation
- 3. Causal masked softmax operation
- 4. Final attention output computation
-
- Expected tensor shapes:
- - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
- - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
- - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
- - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
- - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - seq: [1, 1] - Starting sequence position
- - scale: [] - Attention scale factor
- - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
- - block_size: [] - Number of tokens per block
-
- Returns:
- Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
- """
- return q
-
- @register_fake("rbln_custom_ops::paged_causal_attn_decode")
- def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- return q
-
- torch.library.define(
- "rbln_custom_ops::paged_causal_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
- )
-
- @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
- def attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- """Defines the computation pattern for prefill phase attention with KV cache updates.
-
- IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
- a single optimized NPU operation. It is NOT meant for CPU execution.
-
- Key differences from decode pattern:
- - Handles prefill phase with multiple input tokens
- - Takes explicit batch index for continuous batching
-
- Expected tensor shapes:
- - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
- - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
- - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
- - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
- - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - batch: [1] - Batch index for cache access
- - seq: [1, 1] - Starting sequence position
- - scale: [] - Attention scale factor
- - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
- - block_size: [] - Number of tokens per block
-
- Returns:
- Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
- """
- return q
-
- @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
- def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- return q
-
-
- @lru_cache
- def register_rbln_custom_paged_add_softmax_attention():
- torch.library.define(
- "rbln_custom_ops::paged_add_softmax_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
- )
-
- @torch.library.impl("rbln_custom_ops::paged_add_softmax_attn_decode", "cpu")
- def paged_add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
- """Defines the computation pattern for fused attention with KV cache updates.
-
- IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
- a single optimized NPU operation. It is NOT meant for CPU execution.
-
- Pattern components that compiler fuses into a single op:
- 1. KV cache updates with new key/value states
- 2. Scaled dot-product attention computation
- 3. add-softmax operation
- 4. Final attention output computation
-
- Expected tensor shapes:
- - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
- - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
- - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
- - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
- - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
- - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
- - seq: [1] - Current sequence position
- - scale: [] - Attention scale factor
- - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
- - block_size: [] - Number of tokens per block
-
- Returns:
- Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
- """
- return q
-
- @register_fake("rbln_custom_ops::paged_add_softmax_attn_decode")
- def paged_add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition, block_table, block_size):
- return q
+ from torch import Tensor
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::paged_attn_decode",
+ mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_attn_decode(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ return torch.empty_like(q)
+
+
+ @paged_attn_decode.register_fake
+ def paged_attn_decode_fake(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::paged_attn_prefill",
+ mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_attn_prefill(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+ a single optimized NPU operation. It is NOT meant for CPU execution.
+
+ Key differences from decode pattern:
+ - Handles prefill phase with multiple input tokens
+ - Takes explicit batch index for continuous batching
+
+ Expected tensor shapes:
+ - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+ - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+ - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+ - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+ - seq: [1, 1] - Starting sequence position
+ - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block
+
+ Returns:
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+ """
+ return torch.empty_like(q)
+
+
+ @paged_attn_prefill.register_fake
+ def paged_attn_prefill_fake(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::paged_causal_attn_decode",
+ mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_causal_attn_decode(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ """Defines the computation pattern for fused attention with KV cache updates.
+
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+ a single optimized NPU operation. It is NOT meant for CPU execution.
+
+ Pattern components that compiler fuses into a single op:
+ 1. KV cache updates with new key/value states
+ 2. Scaled dot-product attention computation
+ 3. Causal masked softmax operation
+ 4. Final attention output computation
+
+ Expected tensor shapes:
+ - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+ - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+ - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+ - seq: [1, 1] - Starting sequence position
+ - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block
+
+ Returns:
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
+ """
+ return torch.empty_like(q)
+
+
+ @paged_causal_attn_decode.register_fake
+ def paged_causal_attn_decode_fake(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::paged_causal_attn_prefill",
+ mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_causal_attn_prefill(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+ a single optimized NPU operation. It is NOT meant for CPU execution.
+
+ Key differences from decode pattern:
+ - Handles prefill phase with multiple input tokens
+ - Takes explicit batch index for continuous batching
+
+ Expected tensor shapes:
+ - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+ - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+ - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+ - batch: [1] - Batch index for cache access
+ - seq: [1, 1] - Starting sequence position
+ - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block
+
+ Returns:
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+ """
+ return torch.empty_like(q)
+
+
+ @paged_causal_attn_prefill.register_fake
+ def paged_causal_attn_prefill_fake(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::paged_add_softmax_attn_decode",
+ mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_add_softmax_attn_decode(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ """Defines the computation pattern for fused attention with KV cache updates.
+
+ IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+ a single optimized NPU operation. It is NOT meant for CPU execution.
+
+ Pattern components that compiler fuses into a single op:
+ 1. KV cache updates with new key/value states
+ 2. Scaled dot-product attention computation
+ 3. add-softmax operation
+ 4. Final attention output computation
+
+ Expected tensor shapes:
+ - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+ - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+ - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+ - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
+ - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+ - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+ - seq: [1] - Current sequence position
+ - scale: [] - Attention scale factor
+ - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+ - block_size: [] - Number of tokens per block
+
+ Returns:
+ Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
+ """
+ return torch.empty_like(q)
+
+
+ @paged_add_softmax_attn_decode.register_fake
+ def paged_add_softmax_attn_decode_fake(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Tensor,
+ kcache: Tensor,
+ vcache: Tensor,
+ seq: Tensor,
+ scale: Tensor,
+ block_table: Tensor,
+ block_size: int,
+ ) -> Tensor:
+ return torch.empty_like(q)
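
For reference, the attn.py rewrite swaps the pre-2.4 torch.library registration style (torch.library.define plus torch.library.impl and impl_abstract/register_fake, wrapped in lru_cache-guarded register_*() functions) for the torch.library.custom_op decorator, which infers the op schema from the type hints and registers the op at import time. A condensed sketch of the two styles side by side, using a made-up op name for illustration only (not code from the package):

    import torch
    from torch import Tensor

    # Old style (0.7.4a2): explicit schema string, separate CPU impl and fake impl.
    torch.library.define("rbln_custom_ops::demo_op_v1", "(Tensor x, Tensor y) -> Tensor")

    @torch.library.impl("rbln_custom_ops::demo_op_v1", "cpu")
    def demo_op_v1_cpu(x, y):
        return x  # placeholder pattern body, never meant to run for real on CPU

    @torch.library.register_fake("rbln_custom_ops::demo_op_v1")
    def demo_op_v1_fake(x, y):
        return torch.empty_like(x)

    # New style (0.7.4a4): schema inferred from annotations, fake impl attached to the op object.
    @torch.library.custom_op("rbln_custom_ops::demo_op_v2", mutates_args=())
    def demo_op_v2(x: Tensor, y: Tensor) -> Tensor:
        return torch.empty_like(x)

    @demo_op_v2.register_fake
    def demo_op_v2_fake(x: Tensor, y: Tensor) -> Tensor:
        return torch.empty_like(x)
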