optimum-rbln 0.7.4a2__py3-none-any.whl → 0.7.4a4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/modeling.py +8 -1
- optimum/rbln/modeling_base.py +0 -5
- optimum/rbln/ops/__init__.py +3 -7
- optimum/rbln/ops/attn.py +271 -207
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -37
- optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +20 -32
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/RECORD +24 -24
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
optimum/rbln/modeling.py
CHANGED
```diff
@@ -123,8 +123,15 @@ class RBLNModel(RBLNBaseModel):
         config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
 
         if hasattr(model, "can_generate") and model.can_generate():
+            import json
+
             generation_config = model.generation_config
-
+            generation_config_path = save_dir_path / subfolder / "generation_config.json"
+
+            generation_config.save_pretrained(generation_config_path.parent)
+            local_config = json.loads(generation_config_path.read_text(encoding="utf-8"))
+            local_config["transformers_version"] = generation_config.transformers_version
+            generation_config_path.write_text(json.dumps(local_config, indent=2) + "\n", encoding="utf-8")
 
         if not isinstance(config, PretrainedConfig): # diffusers config
             config = PretrainedConfig(**config)
```
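For orientation, the added lines export the model's `generation_config.json` alongside the saved artifacts and then rewrite its `transformers_version` field with the version recorded on the `GenerationConfig` object itself, rather than the version that `save_pretrained` stamps from the locally installed transformers. A minimal standalone sketch of that round-trip; the model id and output directory below are illustrative and not taken from the diff:

```python
import json
from pathlib import Path

from transformers import GenerationConfig

# Load a generation config that already carries a transformers_version (model id is illustrative).
gen_cfg = GenerationConfig.from_pretrained("gpt2")

out_dir = Path("exported_model")
out_dir.mkdir(exist_ok=True)

# save_pretrained writes generation_config.json, stamping the installed transformers version.
gen_cfg.save_pretrained(out_dir)
cfg_path = out_dir / "generation_config.json"

# Rewrite the field so the exported file keeps the version recorded on the config object.
data = json.loads(cfg_path.read_text(encoding="utf-8"))
data["transformers_version"] = gen_cfg.transformers_version
cfg_path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
```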
optimum/rbln/modeling_base.py
CHANGED
```diff
@@ -481,11 +481,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         # First copy everything to a temporary directory
         shutil.copytree(real_save_dir, tmp_dir)
 
-        # Save configs to the temporary directory
-        self.config.save_pretrained(tmp_dir)
-        if self.generation_config is not None:
-            self.generation_config.save_pretrained(tmp_dir)
-
         # If everything succeeded, atomically replace the target directory
         if os.path.exists(save_directory_path):
             shutil.rmtree(save_directory_path)
```
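The deleted lines used to re-save `self.config` and `self.generation_config` into the staging directory; with the modeling.py change above, the generation config is written where the artifacts are produced, so this copy step no longer needs to duplicate it. The surrounding context keeps the copy-to-temp-then-swap flow. A generic sketch of that flow, where the helper name and temp layout are illustrative and not from the source:

```python
import os
import shutil
import tempfile


def copy_then_swap(real_save_dir: str, save_directory_path: str) -> None:
    """Stage a full copy first, then replace the target only if the copy succeeded."""
    with tempfile.TemporaryDirectory() as staging:
        tmp_dir = os.path.join(staging, "model")
        # First copy everything to a temporary directory
        shutil.copytree(real_save_dir, tmp_dir)
        # If everything succeeded, replace the target directory
        if os.path.exists(save_directory_path):
            shutil.rmtree(save_directory_path)
        shutil.move(tmp_dir, save_directory_path)
```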
optimum/rbln/ops/__init__.py
CHANGED
```diff
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .attn import (
-
-
-    register_rbln_custom_paged_causal_attention,
-)
-from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
-from .kv_cache_update import register_rbln_custom_cache_update
+from .attn import *
+from .flash_attn import *
+from .kv_cache_update import *
 from .linear import linear
```
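The explicit `register_*` imports give way to wildcard imports: with the new `torch.library.custom_op`-based definitions (see `attn.py` below), simply importing each module registers its ops as a side effect, so nothing has to be called by hand. A hedged sketch of what that implies for callers; the check below is illustrative and assumes the package is installed:

```python
import torch

# Importing the ops package is enough to register the custom ops as a side effect.
import optimum.rbln.ops  # noqa: F401

# Once registered, the ops are reachable through the torch.ops dispatcher namespace.
print(hasattr(torch.ops.rbln_custom_ops, "paged_attn_decode"))  # expected: True
```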
optimum/rbln/ops/attn.py
CHANGED
```diff
@@ -12,212 +12,276 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import lru_cache
 
 import torch
-from
-[old lines 19-223 are removed here; their content is not captured in this diff view]
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_attn_decode.register_fake
+def paged_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+    IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+    a single optimized NPU operation. It is NOT meant for CPU execution.
+
+    Key differences from decode pattern:
+    - Handles prefill phase with multiple input tokens
+    - Takes explicit batch index for continuous batching
+
+    Expected tensor shapes:
+    - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+    - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+    - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+    - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
+    - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+    - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+    - seq: [1, 1] - Starting sequence position
+    - scale: [] - Attention scale factor
+    - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+    - block_size: [] - Number of tokens per block
+
+    Returns:
+        Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+    """
+    return torch.empty_like(q)
+
+
+@paged_attn_prefill.register_fake
+def paged_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_causal_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_causal_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    """Defines the computation pattern for fused attention with KV cache updates.
+
+    IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+    a single optimized NPU operation. It is NOT meant for CPU execution.
+
+    Pattern components that compiler fuses into a single op:
+    1. KV cache updates with new key/value states
+    2. Scaled dot-product attention computation
+    3. Causal masked softmax operation
+    4. Final attention output computation
+
+    Expected tensor shapes:
+    - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+    - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+    - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+    - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+    - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+    - seq: [1, 1] - Starting sequence position
+    - scale: [] - Attention scale factor
+    - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+    - block_size: [] - Number of tokens per block
+
+    Returns:
+        Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
+    """
+    return torch.empty_like(q)
+
+
+@paged_causal_attn_decode.register_fake
+def paged_causal_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_causal_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_causal_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+    IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+    a single optimized NPU operation. It is NOT meant for CPU execution.
+
+    Key differences from decode pattern:
+    - Handles prefill phase with multiple input tokens
+    - Takes explicit batch index for continuous batching
+
+    Expected tensor shapes:
+    - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+    - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+    - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+    - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+    - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+    - batch: [1] - Batch index for cache access
+    - seq: [1, 1] - Starting sequence position
+    - scale: [] - Attention scale factor
+    - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+    - block_size: [] - Number of tokens per block
+
+    Returns:
+        Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+    """
+    return torch.empty_like(q)
+
+
+@paged_causal_attn_prefill.register_fake
+def paged_causal_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_add_softmax_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_add_softmax_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    """Defines the computation pattern for fused attention with KV cache updates.
+
+    IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+    a single optimized NPU operation. It is NOT meant for CPU execution.
+
+    Pattern components that compiler fuses into a single op:
+    1. KV cache updates with new key/value states
+    2. Scaled dot-product attention computation
+    3. add-softmax operation
+    4. Final attention output computation
+
+    Expected tensor shapes:
+    - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+    - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+    - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+    - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
+    - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+    - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+    - seq: [1] - Current sequence position
+    - scale: [] - Attention scale factor
+    - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+    - block_size: [] - Number of tokens per block
+
+    Returns:
+        Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
+    """
+    return torch.empty_like(q)
+
+
+@paged_add_softmax_attn_decode.register_fake
+def paged_add_softmax_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+) -> Tensor:
+    return torch.empty_like(q)
```
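All of the new definitions follow the same `torch.library.custom_op` + `register_fake` pattern (available in recent PyTorch, 2.4+): the op body is only a placeholder returning `empty_like(q)`, and the fake kernel gives the tracer shape and dtype information so the pattern can be captured without CPU execution and handed to the RBLN compiler. A self-contained toy version of that pattern; the namespace and op name below are illustrative and not part of optimum-rbln:

```python
import torch
from torch import Tensor


@torch.library.custom_op("demo_ops::toy_attn", mutates_args=())
def toy_attn(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # Eager reference body; a compiler backend would normally replace this whole pattern.
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    return scores.softmax(dim=-1) @ v


@toy_attn.register_fake
def _(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # Shape/dtype-only kernel used under FakeTensor tracing (e.g. torch.export).
    return torch.empty_like(q)


q = k = v = torch.randn(1, 4, 8, 16)
print(toy_attn(q, k, v).shape)  # torch.Size([1, 4, 8, 16]); dispatches via torch.ops.demo_ops.toy_attn
```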