optimum-rbln 0.7.2rc2__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +8 -0
- optimum/rbln/diffusers/modeling_diffusers.py +103 -117
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
- optimum/rbln/diffusers/pipelines/__init__.py +8 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
- optimum/rbln/modeling.py +4 -1
- optimum/rbln/modeling_base.py +16 -3
- optimum/rbln/ops/__init__.py +6 -2
- optimum/rbln/ops/attn.py +94 -85
- optimum/rbln/ops/flash_attn.py +46 -25
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
- optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
- optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
CHANGED
@@ -295,6 +295,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     ):
         if isinstance(model_save_dir, str):
            model_save_dir = Path(model_save_dir)
+
        # FIXME:: Should we convert it?
        compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
        rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
@@ -389,8 +390,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
        return rbln_config

    @classmethod
-
-    def hf_class(cls):
+    def get_hf_class(cls):
        """
        Lazily loads and caches the corresponding Hugging Face model class.
        Removes 'RBLN' prefix from the class name to get the original class name
@@ -416,7 +416,20 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
        return self.forward(*args, **kwargs)

    def __repr__(self):
-
+        has_submodules = len(self.rbln_submodules) > 0
+        repr_str: str = f"<{self.__class__.__name__}>\n"
+        repr_str += f"- Total {len(self.model)} Runtimes"
+        repr_str += f" and {len(self.rbln_submodules)} Submodules\n" if has_submodules else "\n"
+        repr_str += "[Runtimes]\n"
+        repr_str += "\n".join([repr(model) for model in self.model])
+        repr_str += "\n"
+
+        if has_submodules > 0:
+            for i, submodule in enumerate(self.rbln_submodules):
+                repr_str += f"[Submodules {i} : {self._rbln_submodules[i]['name']}]\n"
+                repr_str += repr(submodule) + "\n"
+
+        return repr_str

    def __post_init__(self, **kwargs):
        pass
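Note on this change: `hf_class` is replaced by the classmethod `get_hf_class()`, so downstream code that read the old attribute needs updating. A minimal sketch of the new call site, using `RBLNBertModel` purely as an example and assuming it is exported from the top-level package:

    from optimum.rbln import RBLNBertModel

    # Per the docstring above, the 'RBLN' prefix is stripped and the matching
    # Hugging Face class is lazily loaded and cached.
    hf_cls = RBLNBertModel.get_hf_class()
    print(hf_cls.__name__)  # expected: "BertModel"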
optimum/rbln/ops/__init__.py
CHANGED
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .attn import
-
+from .attn import (
+    register_rbln_custom_add_softmax_attention,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+)
+from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
 from .kv_cache_update import register_rbln_custom_cache_update
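The re-export list above is the public surface for registering the custom ops; each helper is wrapped in `@lru_cache`, so calling it more than once is harmless. A short sketch of pulling them in under the new names (the call order is not significant):

    from optimum.rbln.ops import (
        register_rbln_custom_add_softmax_attention,
        register_rbln_custom_paged_attention,
        register_rbln_custom_paged_causal_attention,
        register_rbln_custom_paged_flash_attention,
        register_rbln_custom_paged_flash_causal_attention,
        register_rbln_custom_cache_update,
    )

    # Make the op schemas visible to torch before tracing a model that uses them.
    register_rbln_custom_paged_causal_attention()
    register_rbln_custom_cache_update()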
optimum/rbln/ops/attn.py
CHANGED
@@ -25,14 +25,14 @@ else:


 @lru_cache
-def register_rbln_custom_attention():
+def register_rbln_custom_paged_attention():
     torch.library.define(
-        "rbln_custom_ops::attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor
+        "rbln_custom_ops::paged_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )

-    @torch.library.impl("rbln_custom_ops::attn_decode", "cpu")
-    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+    @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for fused attention with KV cache updates.

         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -51,36 +51,27 @@ def register_rbln_custom_attention():
         - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
-        - seq: [1] - Current sequence position
+        - seq: [1, 1] - Current sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block

         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return
-
-
-
-
-
-    @register_fake("rbln_custom_ops::attn_decode")
-    def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
+
+    @register_fake("rbln_custom_ops::paged_attn_decode")
+    def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
+        return q

     torch.library.define(
-        "rbln_custom_ops::
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor
+        "rbln_custom_ops::paged_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )

-    @torch.library.impl("rbln_custom_ops::
-    def attn_prefill_cpu(q, k, v, mask, kcache, vcache,
+    @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for prefill phase attention with KV cache updates.

         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -97,32 +88,30 @@ def register_rbln_custom_attention():
         - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
-        -
-        - seq: [1] - Starting sequence position
+        - seq: [1, 1] - Starting sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block

         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
-            - empty_kcache: Same shape as input kcache - Placeholder for compiler
-            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return q
+        return q

-    @register_fake("rbln_custom_ops::
-    def attn_prefill_abstract(q, k, v, m, kcache, vcache,
-        return q
+    @register_fake("rbln_custom_ops::paged_attn_prefill")
+    def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
+        return q


 @lru_cache
-def register_rbln_custom_attention_add_softmax():
+def register_rbln_custom_paged_causal_attention():
     torch.library.define(
-        "rbln_custom_ops::attn_decode_add_softmax",
-        "(Tensor x, Tensor y, Tensor z, Tensor
+        "rbln_custom_ops::paged_causal_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )

-    @torch.library.impl("rbln_custom_ops::
-    def
+    @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for fused attention with KV cache updates.

         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -131,46 +120,36 @@ def register_rbln_custom_attention_add_softmax():
         Pattern components that compiler fuses into a single op:
         1. KV cache updates with new key/value states
         2. Scaled dot-product attention computation
-        3.
+        3. Causal masked softmax operation
         4. Final attention output computation

         Expected tensor shapes:
         - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
         - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
         - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
-        - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
-        - seq: [1] -
+        - seq: [1, 1] - Starting sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block

         Returns:
-
-            - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return
-
-
-
-
-
-    @register_fake("rbln_custom_ops::attn_decode_add_softmax")
-    def attn_decode_add_softmax_abstract(q, k, v, m, kcache, vcache, seq, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
+
+    @register_fake("rbln_custom_ops::paged_causal_attn_decode")
+    def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
+        return q

     torch.library.define(
-        "rbln_custom_ops::
-        "(Tensor x, Tensor y, Tensor z, Tensor
+        "rbln_custom_ops::paged_causal_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )

-    @torch.library.impl("rbln_custom_ops::
-    def
+    @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for prefill phase attention with KV cache updates.

         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -184,29 +163,59 @@ def register_rbln_custom_attention_add_softmax():
         - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
         - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
         - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
-        - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
         - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
         - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
         - batch: [1] - Batch index for cache access
-        - seq: [1] - Starting sequence position
+        - seq: [1, 1] - Starting sequence position
         - scale: [] - Attention scale factor
+        - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+        - block_size: [] - Number of tokens per block

         Returns:
-
-            - attn_output: [batch=1, n_heads, seq_len, 1, head_dim] - Attention output
-            - empty_kcache: Same shape as input kcache - Placeholder for compiler
-            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return
-
-
-
-
-
-
-
-
-
-
-
+        return q
+
+    @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
+    def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
+        return q
+
+
+@lru_cache
+def register_rbln_custom_add_softmax_attention():
+    torch.library.define(
+        "rbln_custom_ops::add_softmax_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
+    )
+
+    @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
+    def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+        """Defines the computation pattern for fused attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Pattern components that compiler fuses into a single op:
+        1. KV cache updates with new key/value states
+        2. Scaled dot-product attention computation
+        3. add-softmax operation
+        4. Final attention output computation
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+        - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+        - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - seq: [1] - Current sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
+        """
+        return q
+
+    @register_fake("rbln_custom_ops::add_softmax_attn_decode")
+    def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+        return q
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -25,37 +25,58 @@ else:


 @lru_cache
-def
+def register_rbln_custom_paged_flash_attention():
     torch.library.define(
-        "rbln_custom_ops::flash_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int
+        "rbln_custom_ops::paged_flash_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )

-    @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
-    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
-        return
-
-
-
-
-
-    @register_fake("rbln_custom_ops::flash_attn_decode")
-    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+    @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    @register_fake("rbln_custom_ops::paged_flash_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q

     torch.library.define(
-        "rbln_custom_ops::flash_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "rbln_custom_ops::paged_flash_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )

     @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
-    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache,
-        return q
+    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+
+@lru_cache
+def register_rbln_custom_paged_flash_causal_attention():
+    torch.library.define(
+        "rbln_custom_ops::paged_flash_causal_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
+    )
+
+    @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
+
+    torch.library.define(
+        "rbln_custom_ops::paged_flash_causal_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
+    )
+
+    @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
+    def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q

-    @register_fake("rbln_custom_ops::
-    def flash_attn_prefill_abstract(q, k, v,
-        return q
+    @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
+        return q
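The flash variants mirror the paged schemas but append a partition length (`int g`) after `block_size`, and the causal forms drop the explicit mask tensor. A hedged illustration of the causal decode signature (all values below are placeholders, not library defaults):

    import torch
    from optimum.rbln.ops import register_rbln_custom_paged_flash_causal_attention

    register_rbln_custom_paged_flash_causal_attention()

    q = torch.randn(1, 8, 4, 1, 64)
    k = torch.randn(1, 8, 1, 1, 64)
    v = torch.randn(1, 8, 1, 1, 64)
    kcache = torch.zeros(1, 8, 1, 8192, 64)
    vcache = torch.zeros(1, 8, 1, 8192, 64)
    seq = torch.zeros(1, 1, dtype=torch.int32)
    scale = torch.tensor(64**-0.5)
    block_table = torch.arange(8192 // 128, dtype=torch.int16).unsqueeze(0)

    out = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
        q, k, v, kcache, vcache, seq, scale, block_table, 128, 4096  # block_size, partition
    )
    assert out.shape == q.shape  # the CPU stub returns q unchanged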
optimum/rbln/ops/kv_cache_update.py
CHANGED
@@ -45,10 +45,10 @@ def register_rbln_custom_cache_update():

         # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
         # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-
+        cache.slice_scatter(state, dim=axis, start=s, end=e)

-        #
-        return
+        # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
+        return torch.empty([256])

     # Register a "fake" implementation of the "rbln_cache_update" operation.
     # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
@@ -57,4 +57,4 @@ def register_rbln_custom_cache_update():
         # Return a tensor with the same shape as the input cache tensor.
         # This is a placeholder for the abstract implementation and does not perform any actual computation.
        # Like the actual implementation, the abstraction assumes in-place device-side updates.
-        return torch.
+        return torch.empty([256])
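The comment above leans on `Tensor.slice_scatter`; note that in eager PyTorch it returns a new tensor rather than mutating its input, and the in-place, on-device behavior described here is what the RBLN compiler is expected to emit for the fused op. A standalone illustration of the slicing semantics:

    import torch

    cache = torch.zeros(1, 2, 8, 4)   # e.g. [batch, heads, max_seq_len, head_dim]
    state = torch.ones(1, 2, 3, 4)    # new K/V entries destined for positions 2..4

    # Writes `state` into cache[:, :, 2:5, :]; eager execution returns a copy.
    updated = cache.slice_scatter(state, dim=2, start=2, end=5)
    assert updated[:, :, 2:5].eq(1).all() and cache.eq(0).all()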
optimum/rbln/transformers/modeling_generic.py
CHANGED
@@ -73,7 +73,7 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1

-        signature_params = inspect.signature(cls.
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()

         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -289,7 +289,7 @@ class RBLNModelForSequenceClassification(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")

-        signature_params = inspect.signature(cls.
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()

         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -362,7 +362,7 @@ class RBLNModelForMaskedLM(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")

-        signature_params = inspect.signature(cls.
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()

         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
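All three hunks above swap the removed `hf_class` lookup for `get_hf_class()` when collecting the original `forward` parameter names. The idiom itself is plain `inspect`; a self-contained sketch with a dummy `forward` standing in for the Hugging Face method:

    import inspect

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Stand-in for the original Hugging Face forward()."""

    signature_params = inspect.signature(forward).parameters.keys()
    print(list(signature_params))
    # ['self', 'input_ids', 'attention_mask', 'token_type_ids']
    # The surrounding code then matches these names against the preprocessors to pick rbln_model_input_names.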
optimum/rbln/transformers/models/bart/bart_architecture.py
CHANGED
@@ -35,16 +35,16 @@ logger = logging.get_logger(__name__)


 class BartWrapper:
-    def __init__(self, model: nn.Module, enc_max_seq_len: int):
+    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
-        self.decoder = BartDecoderWrapper(model)
+        self.decoder = BartDecoderWrapper(model, use_attention_mask=use_attention_mask)


 class BartDecoderWrapper(Seq2SeqDecoderWrapper):
     def convert_to_rbln_conditional_generation(self, model: nn.Module):
         new_layers = []
         for layer in model.get_decoder().layers:
-            self_attn = BartSelfAttention(layer.self_attn)
+            self_attn = BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
             new_layers.append(BartDecoderLayer(layer, self_attn))

         decoder_model = BartDecoder(model.get_decoder(), new_layers)
@@ -69,7 +69,8 @@ class BartDecoder(Seq2SeqDecoder):
         self.embed_scale = getattr(self._original_mod, "embed_scale", None)

     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
-        attention_mask
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :]
         encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)

         return attention_mask, encoder_attention_mask
@@ -134,7 +135,7 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):


 class BartSelfAttention(Seq2SeqSelfAttention):
-    def __post_init__(self):
+    def __post_init__(self, use_attention_mask: bool = True):
         self.q_proj = self._original_mod.q_proj
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
@@ -142,7 +143,10 @@ class BartSelfAttention(Seq2SeqSelfAttention):
         self.num_heads = self._original_mod.num_heads
         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
         self.scaling = self.head_dim**-0.5
-
+        if use_attention_mask:
+            self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+        else:
+            self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode

     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states) * self.scaling
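The `use_attention_mask` flag threaded through `BartWrapper`, `BartDecoderWrapper`, and `BartSelfAttention.__post_init__` ultimately just chooses which registered custom op backs `attn_decode`. A minimal standalone sketch of that dispatch (not the optimum-rbln classes themselves):

    import torch
    from optimum.rbln.ops import (
        register_rbln_custom_paged_attention,
        register_rbln_custom_paged_causal_attention,
    )

    register_rbln_custom_paged_attention()
    register_rbln_custom_paged_causal_attention()

    use_attention_mask = False  # illustrative value; the real flag comes from rbln_config.model_cfg
    attn_decode = (
        torch.ops.rbln_custom_ops.paged_attn_decode
        if use_attention_mask
        else torch.ops.rbln_custom_ops.paged_causal_attn_decode
    )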
optimum/rbln/transformers/models/bart/modeling_bart.py
CHANGED
@@ -58,7 +58,7 @@ class RBLNBartModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

-        signature_params = inspect.signature(cls.
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()

         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -108,12 +108,16 @@ class RBLNBartModel(RBLNModel):


 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_paged_causal_attn = True
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = (
             rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
         )
-
+        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
+
+        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
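For context, a hedged usage sketch of the conditional-generation entry point: `export=True` and the `rbln_`-prefixed kwargs follow optimum-rbln's usual convention, `rbln_enc_max_seq_len` corresponds to the `enc_max_seq_len` entry read above, and `rbln_use_attention_mask` is a hypothetical override for the new `use_attention_mask` flag (which defaults to False in `wrap_model_if_needed`):

    from optimum.rbln import RBLNBartForConditionalGeneration

    # Compile for the RBLN NPU; with the new default the decoder is wrapped with
    # use_attention_mask=False, i.e. it relies on the paged *causal* attention op.
    model = RBLNBartForConditionalGeneration.from_pretrained(
        "facebook/bart-base",
        export=True,
        rbln_enc_max_seq_len=1024,     # falls back to 1024 when absent from model_cfg
        rbln_use_attention_mask=True,  # hypothetical knob selecting paged_attn_decode instead
    )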
optimum/rbln/transformers/models/bert/modeling_bert.py
CHANGED
@@ -56,7 +56,7 @@ class RBLNBertModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

-        signature_params = inspect.signature(cls.
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()

         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
|