optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +1 -1
- optimum/rbln/ops/attn.py +9 -7
- optimum/rbln/transformers/models/bart/modeling_bart.py +2 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +12 -12
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -210
- optimum/rbln/transformers/models/t5/t5_architecture.py +9 -3
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +12 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +13 -3
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a0.dist-info}/RECORD +13 -13
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
optimum/rbln/ops/__init__.py
CHANGED
optimum/rbln/ops/attn.py
CHANGED
@@ -182,14 +182,14 @@ def register_rbln_custom_paged_causal_attention():
|
|
182
182
|
|
183
183
|
|
184
184
|
@lru_cache
|
185
|
-
def
|
185
|
+
def register_rbln_custom_paged_add_softmax_attention():
|
186
186
|
torch.library.define(
|
187
|
-
"rbln_custom_ops::
|
188
|
-
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
|
187
|
+
"rbln_custom_ops::paged_add_softmax_attn_decode",
|
188
|
+
"(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
|
189
189
|
)
|
190
190
|
|
191
|
-
@torch.library.impl("rbln_custom_ops::
|
192
|
-
def
|
191
|
+
@torch.library.impl("rbln_custom_ops::paged_add_softmax_attn_decode", "cpu")
|
192
|
+
def paged_add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
|
193
193
|
"""Defines the computation pattern for fused attention with KV cache updates.
|
194
194
|
|
195
195
|
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
@@ -210,12 +210,14 @@ def register_rbln_custom_add_softmax_attention():
|
|
210
210
|
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
211
211
|
- seq: [1] - Current sequence position
|
212
212
|
- scale: [] - Attention scale factor
|
213
|
+
- block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
|
214
|
+
- block_size: [] - Number of tokens per block
|
213
215
|
|
214
216
|
Returns:
|
215
217
|
Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
|
216
218
|
"""
|
217
219
|
return q
|
218
220
|
|
219
|
-
@register_fake("rbln_custom_ops::
|
220
|
-
def
|
221
|
+
@register_fake("rbln_custom_ops::paged_add_softmax_attn_decode")
|
222
|
+
def paged_add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition, block_table, block_size):
|
221
223
|
return q
|
@@ -108,6 +108,8 @@ class RBLNBartModel(RBLNModel):
|
|
108
108
|
|
109
109
|
|
110
110
|
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
111
|
+
support_causal_attn = True
|
112
|
+
|
111
113
|
@classmethod
|
112
114
|
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
113
115
|
enc_max_seq_len = (
|
@@ -94,7 +94,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
94
94
|
decoder_attention_mask if self.use_attention_mask else None,
|
95
95
|
attention_mask,
|
96
96
|
cache_position,
|
97
|
-
block_tables,
|
97
|
+
block_tables=block_tables,
|
98
98
|
)
|
99
99
|
|
100
100
|
return Seq2SeqLMOutput(logits=lm_logits)
|
@@ -115,6 +115,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
115
115
|
|
116
116
|
main_input_name = "input_ids"
|
117
117
|
auto_model_class = AutoModelForSeq2SeqLM
|
118
|
+
support_causal_attn = None
|
118
119
|
|
119
120
|
def __post_init__(self, **kwargs):
|
120
121
|
batch_size = self.rbln_config.model_cfg["batch_size"]
|
@@ -186,13 +187,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
186
187
|
rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
|
187
188
|
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
188
189
|
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
189
|
-
rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
|
190
190
|
|
191
|
-
if
|
192
|
-
rbln_use_attention_mask =
|
193
|
-
|
194
|
-
|
195
|
-
|
191
|
+
if cls.support_causal_attn:
|
192
|
+
rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
|
193
|
+
if rbln_use_attention_mask is None:
|
194
|
+
rbln_use_attention_mask = False
|
195
|
+
rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
|
196
|
+
if rbln_npu == "RBLN-CA02":
|
197
|
+
rbln_use_attention_mask = True
|
198
|
+
else:
|
199
|
+
rbln_use_attention_mask = True
|
196
200
|
|
197
201
|
n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
|
198
202
|
n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
|
@@ -265,11 +269,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
265
269
|
[rbln_batch_size, 1],
|
266
270
|
"int32",
|
267
271
|
),
|
268
|
-
(
|
269
|
-
"block_tables",
|
270
|
-
[rbln_batch_size, 1],
|
271
|
-
"int16",
|
272
|
-
),
|
272
|
+
("block_tables", [rbln_batch_size, 1], "int16"),
|
273
273
|
]
|
274
274
|
dec_input_info.extend(
|
275
275
|
[
|
@@ -13,9 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict,
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
17
17
|
|
18
|
-
import rebel
|
19
18
|
import torch
|
20
19
|
from transformers import (
|
21
20
|
AutoModelForTextEncoding,
|
@@ -23,7 +22,7 @@ from transformers import (
|
|
23
22
|
T5EncoderModel,
|
24
23
|
T5ForConditionalGeneration,
|
25
24
|
)
|
26
|
-
from transformers.modeling_outputs import BaseModelOutput
|
25
|
+
from transformers.modeling_outputs import BaseModelOutput
|
27
26
|
|
28
27
|
from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
|
29
28
|
from ....modeling import RBLNModel
|
@@ -58,63 +57,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
58
57
|
)
|
59
58
|
|
60
59
|
|
61
|
-
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
62
|
-
mandatory_members = ["main_input_name"]
|
63
|
-
|
64
|
-
def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
|
65
|
-
_ = super().forward(*args, **kwargs)
|
66
|
-
return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
|
67
|
-
|
68
|
-
|
69
|
-
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
70
|
-
mandatory_members = ["main_input_name"]
|
71
|
-
|
72
|
-
def __init__(
|
73
|
-
self,
|
74
|
-
runtime: rebel.Runtime,
|
75
|
-
batch_size: int,
|
76
|
-
dec_max_seq_len: int,
|
77
|
-
**kwargs: Any,
|
78
|
-
) -> None:
|
79
|
-
super().__init__(runtime, **kwargs)
|
80
|
-
self.batch_size = batch_size
|
81
|
-
self.dec_max_seq_len = dec_max_seq_len
|
82
|
-
|
83
|
-
def forward(
|
84
|
-
self,
|
85
|
-
decoder_input_ids: Optional[torch.LongTensor] = None,
|
86
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
87
|
-
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
88
|
-
cache_position: Optional[torch.Tensor] = None,
|
89
|
-
**kwargs,
|
90
|
-
) -> Tuple[torch.FloatTensor]:
|
91
|
-
batch_size = decoder_input_ids.shape[0]
|
92
|
-
if batch_size != self.batch_size:
|
93
|
-
raise RuntimeError(
|
94
|
-
f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
|
95
|
-
)
|
96
|
-
|
97
|
-
if batch_size != cache_position.shape[0]:
|
98
|
-
raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
|
99
|
-
|
100
|
-
for b_idx in range(self.batch_size):
|
101
|
-
decoding_step = cache_position[b_idx].item()
|
102
|
-
if not (0 <= decoding_step < self.dec_max_seq_len):
|
103
|
-
raise ValueError(
|
104
|
-
f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
|
105
|
-
)
|
106
|
-
decoder_attention_mask[b_idx, : decoding_step + 1] = 1
|
107
|
-
|
108
|
-
lm_logits = super().forward(
|
109
|
-
decoder_input_ids,
|
110
|
-
decoder_attention_mask,
|
111
|
-
attention_mask,
|
112
|
-
cache_position,
|
113
|
-
)
|
114
|
-
|
115
|
-
return Seq2SeqLMOutput(logits=lm_logits)
|
116
|
-
|
117
|
-
|
118
60
|
class T5EncoderWrapper(torch.nn.Module):
|
119
61
|
def __init__(self, model: "T5EncoderModel") -> None:
|
120
62
|
super().__init__()
|
@@ -247,20 +189,7 @@ class RBLNT5EncoderModel(RBLNModel):
|
|
247
189
|
|
248
190
|
|
249
191
|
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
250
|
-
|
251
|
-
batch_size = self.rbln_config.model_cfg["batch_size"]
|
252
|
-
dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
|
253
|
-
|
254
|
-
self.encoder = RBLNRuntimeEncoder(
|
255
|
-
runtime=self.model[0],
|
256
|
-
main_input_name="input_ids",
|
257
|
-
)
|
258
|
-
self.decoder = RBLNRuntimeDecoder(
|
259
|
-
runtime=self.model[1],
|
260
|
-
main_input_name="input_ids",
|
261
|
-
batch_size=batch_size,
|
262
|
-
dec_max_seq_len=dec_max_seq_len,
|
263
|
-
)
|
192
|
+
support_causal_attn = False
|
264
193
|
|
265
194
|
@classmethod
|
266
195
|
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
@@ -279,139 +208,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
|
279
208
|
return redirect(val)
|
280
209
|
|
281
210
|
return val
|
282
|
-
|
283
|
-
@classmethod
|
284
|
-
def _get_rbln_config(
|
285
|
-
cls,
|
286
|
-
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
287
|
-
model_config: "PretrainedConfig",
|
288
|
-
rbln_kwargs: Dict[str, Any] = {},
|
289
|
-
) -> RBLNConfig:
|
290
|
-
rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
|
291
|
-
rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
|
292
|
-
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
293
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
294
|
-
|
295
|
-
n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
|
296
|
-
n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
|
297
|
-
d_kv = (
|
298
|
-
model_config.d_kv
|
299
|
-
if hasattr(model_config, "d_kv")
|
300
|
-
else model_config.d_model // model_config.encoder_attention_heads
|
301
|
-
)
|
302
|
-
|
303
|
-
max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
|
304
|
-
model_config, "max_position_embeddings", None
|
305
|
-
)
|
306
|
-
|
307
|
-
rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
|
308
|
-
if rbln_pad_token_id is None:
|
309
|
-
rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
|
310
|
-
if rbln_pad_token_id is None:
|
311
|
-
rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
|
312
|
-
if rbln_pad_token_id is None:
|
313
|
-
rbln_pad_token_id = -1
|
314
|
-
|
315
|
-
if rbln_enc_max_seq_len is None:
|
316
|
-
rbln_enc_max_seq_len = max_position_embeddings
|
317
|
-
if rbln_enc_max_seq_len is None:
|
318
|
-
for tokenizer in preprocessors:
|
319
|
-
if hasattr(tokenizer, "model_max_length"):
|
320
|
-
rbln_enc_max_seq_len = tokenizer.model_max_length
|
321
|
-
break
|
322
|
-
if rbln_enc_max_seq_len is None:
|
323
|
-
raise ValueError("`rbln_enc_max_seq_len` should be specified!")
|
324
|
-
if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
|
325
|
-
raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
326
|
-
|
327
|
-
if rbln_dec_max_seq_len is None:
|
328
|
-
rbln_dec_max_seq_len = max_position_embeddings
|
329
|
-
if rbln_dec_max_seq_len is None:
|
330
|
-
for tokenizer in preprocessors:
|
331
|
-
if hasattr(tokenizer, "model_max_length"):
|
332
|
-
rbln_dec_max_seq_len = tokenizer.model_max_length
|
333
|
-
break
|
334
|
-
if rbln_dec_max_seq_len is None:
|
335
|
-
raise ValueError("`rbln_dec_max_seq_len` should be specified!")
|
336
|
-
|
337
|
-
if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
|
338
|
-
raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
339
|
-
|
340
|
-
# model input info
|
341
|
-
enc_input_info = [
|
342
|
-
("input_ids", [1, rbln_enc_max_seq_len], "int64"),
|
343
|
-
("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
|
344
|
-
(
|
345
|
-
"cross_key_value_states",
|
346
|
-
[
|
347
|
-
n_layer * 2,
|
348
|
-
rbln_batch_size,
|
349
|
-
n_head,
|
350
|
-
rbln_enc_max_seq_len,
|
351
|
-
d_kv,
|
352
|
-
],
|
353
|
-
"float32",
|
354
|
-
),
|
355
|
-
("block_tables", [1], "int16"),
|
356
|
-
]
|
357
|
-
|
358
|
-
dec_input_info = [
|
359
|
-
("input_ids", [rbln_batch_size, 1], "int64"),
|
360
|
-
("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
|
361
|
-
("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
|
362
|
-
(
|
363
|
-
"cache_position",
|
364
|
-
[rbln_batch_size, 1],
|
365
|
-
"int32",
|
366
|
-
),
|
367
|
-
]
|
368
|
-
dec_input_info.extend(
|
369
|
-
[
|
370
|
-
(
|
371
|
-
"cross_key_value_states",
|
372
|
-
[
|
373
|
-
n_layer * 2,
|
374
|
-
rbln_batch_size,
|
375
|
-
n_head,
|
376
|
-
rbln_enc_max_seq_len,
|
377
|
-
d_kv,
|
378
|
-
],
|
379
|
-
"float32",
|
380
|
-
)
|
381
|
-
]
|
382
|
-
)
|
383
|
-
dec_input_info.extend(
|
384
|
-
[
|
385
|
-
(
|
386
|
-
f"self_key_value_states_{i}",
|
387
|
-
[
|
388
|
-
rbln_batch_size,
|
389
|
-
n_head,
|
390
|
-
rbln_dec_max_seq_len,
|
391
|
-
d_kv,
|
392
|
-
],
|
393
|
-
"float32",
|
394
|
-
)
|
395
|
-
for i in range(n_layer * 2)
|
396
|
-
]
|
397
|
-
)
|
398
|
-
|
399
|
-
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
400
|
-
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
401
|
-
|
402
|
-
rbln_config = RBLNConfig(
|
403
|
-
rbln_cls=cls.__name__,
|
404
|
-
compile_cfgs=[enc_compile_config, dec_compile_config],
|
405
|
-
rbln_kwargs=rbln_kwargs,
|
406
|
-
)
|
407
|
-
|
408
|
-
rbln_config.model_cfg.update(
|
409
|
-
{
|
410
|
-
"enc_max_seq_len": rbln_enc_max_seq_len,
|
411
|
-
"dec_max_seq_len": rbln_dec_max_seq_len,
|
412
|
-
"batch_size": rbln_batch_size,
|
413
|
-
"pad_token_id": rbln_pad_token_id,
|
414
|
-
}
|
415
|
-
)
|
416
|
-
|
417
|
-
return rbln_config
|
@@ -18,7 +18,7 @@ import torch
|
|
18
18
|
from torch import nn
|
19
19
|
from transformers.utils import logging
|
20
20
|
|
21
|
-
from ....ops import
|
21
|
+
from ....ops import register_rbln_custom_paged_add_softmax_attention
|
22
22
|
from ..seq2seq.seq2seq_architecture import (
|
23
23
|
Seq2SeqDecoder,
|
24
24
|
Seq2SeqDecoderLayer,
|
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
|
|
55
55
|
|
56
56
|
class T5DecoderWrapper(Seq2SeqDecoderWrapper):
|
57
57
|
def __post_init__(self, model, dec_max_seq_len: int = None):
|
58
|
-
|
58
|
+
register_rbln_custom_paged_add_softmax_attention()
|
59
59
|
self.num_layers = self.config.num_layers
|
60
60
|
self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
|
61
61
|
|
@@ -77,6 +77,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
|
|
77
77
|
attention_mask,
|
78
78
|
encoder_attention_mask,
|
79
79
|
cache_position,
|
80
|
+
block_tables,
|
80
81
|
cross_kv_cache,
|
81
82
|
*self_kv_cache,
|
82
83
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
|
@@ -95,6 +96,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
|
|
95
96
|
self_past_key_values=self_past_key_values,
|
96
97
|
cross_past_key_values=cross_past_key_values,
|
97
98
|
cache_position=cache_position,
|
99
|
+
block_tables=block_tables,
|
98
100
|
)
|
99
101
|
|
100
102
|
return lm_logits
|
@@ -162,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
162
164
|
self.out_proj = self._original_mod.o
|
163
165
|
self.num_heads = self._original_mod.n_heads
|
164
166
|
self.head_dim = self._original_mod.key_value_proj_dim
|
165
|
-
self.attn_decode = torch.ops.rbln_custom_ops.
|
167
|
+
self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
|
166
168
|
|
167
169
|
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
168
170
|
query_states = self.q_proj(hidden_states)
|
@@ -176,6 +178,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
176
178
|
past_key_value: Tuple[torch.Tensor],
|
177
179
|
attention_mask: torch.Tensor,
|
178
180
|
cache_position: torch.Tensor,
|
181
|
+
block_tables: torch.Tensor,
|
179
182
|
**kwargs,
|
180
183
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
181
184
|
bsz, tgt_len, _ = hidden_states.size()
|
@@ -185,6 +188,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
185
188
|
key_states = self._shape(key_states, -1, bsz)
|
186
189
|
value_states = self._shape(value_states, -1, bsz)
|
187
190
|
|
191
|
+
block_size = past_key_value[0].shape[-2]
|
188
192
|
attn_output = self.attn_decode(
|
189
193
|
query_states,
|
190
194
|
key_states,
|
@@ -196,6 +200,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
196
200
|
past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
|
197
201
|
cache_position,
|
198
202
|
torch.tensor(1.0, dtype=torch.float32), # scale
|
203
|
+
block_tables,
|
204
|
+
block_size,
|
199
205
|
)
|
200
206
|
|
201
207
|
attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
|
@@ -61,6 +61,16 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
|
61
61
|
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
62
62
|
mandatory_members = ["main_input_name"]
|
63
63
|
|
64
|
+
def __init__(
|
65
|
+
self,
|
66
|
+
runtime: rebel.Runtime,
|
67
|
+
batch_size: int,
|
68
|
+
**kwargs: Any,
|
69
|
+
) -> None:
|
70
|
+
super().__init__(runtime, **kwargs)
|
71
|
+
self.batch_size = batch_size
|
72
|
+
self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
|
73
|
+
|
64
74
|
def forward(
|
65
75
|
self,
|
66
76
|
decoder_input_ids: torch.Tensor = None,
|
@@ -76,6 +86,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
76
86
|
decoder_input_ids=decoder_input_ids,
|
77
87
|
decoder_attention_mask=decoder_attention_mask,
|
78
88
|
cache_position=cache_position,
|
89
|
+
block_tables=self.default_block_tables,
|
79
90
|
)
|
80
91
|
|
81
92
|
if isinstance(outputs, torch.Tensor):
|
@@ -237,6 +248,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
|
|
237
248
|
("decoder_input_ids", [rbln_batch_size, 1], "int64"),
|
238
249
|
("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
|
239
250
|
("cache_position", [], "int32"),
|
251
|
+
("block_tables", [rbln_batch_size, 1], "int16"),
|
240
252
|
]
|
241
253
|
dec_input_info.extend(
|
242
254
|
[
|
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
|
|
25
25
|
)
|
26
26
|
from transformers.utils import logging
|
27
27
|
|
28
|
-
from ....ops import
|
28
|
+
from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention
|
29
29
|
|
30
30
|
|
31
31
|
logger = logging.get_logger(__name__)
|
@@ -34,7 +34,7 @@ logger = logging.get_logger(__name__)
|
|
34
34
|
class WhisperWrapper:
|
35
35
|
def __init__(self, model, rbln_token_timestamps):
|
36
36
|
register_rbln_custom_cache_update()
|
37
|
-
|
37
|
+
register_rbln_custom_paged_add_softmax_attention()
|
38
38
|
self.encoder = WhisperEncoderWrapper(model)
|
39
39
|
self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
|
40
40
|
|
@@ -108,6 +108,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
|
|
108
108
|
decoder_input_ids: torch.Tensor,
|
109
109
|
decoder_attention_mask: torch.Tensor,
|
110
110
|
cache_position: torch.Tensor,
|
111
|
+
block_tables: torch.Tensor,
|
111
112
|
cross_kv_cache: torch.Tensor,
|
112
113
|
*self_kv_cache: torch.Tensor,
|
113
114
|
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
@@ -125,6 +126,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
|
|
125
126
|
cache_position=cache_position,
|
126
127
|
self_past_key_values=self_past_key_values,
|
127
128
|
cross_past_key_values=cross_past_key_values,
|
129
|
+
block_tables=block_tables,
|
128
130
|
)
|
129
131
|
|
130
132
|
lm_logits = self.proj_out(sequence_output)
|
@@ -154,6 +156,7 @@ class WhisperDecoder(nn.Module):
|
|
154
156
|
self_past_key_values: Optional[torch.Tensor] = None,
|
155
157
|
cross_past_key_values: Optional[torch.Tensor] = None,
|
156
158
|
cache_position: Optional[torch.Tensor] = None,
|
159
|
+
block_tables: Optional[torch.Tensor] = None,
|
157
160
|
):
|
158
161
|
input_shape = input_ids.size()
|
159
162
|
input_ids = input_ids.view(-1, input_shape[-1])
|
@@ -177,6 +180,7 @@ class WhisperDecoder(nn.Module):
|
|
177
180
|
self_past_key_value=self_past_key_value,
|
178
181
|
cross_past_key_value=cross_past_key_value,
|
179
182
|
cache_position=cache_position,
|
183
|
+
block_tables=block_tables,
|
180
184
|
)
|
181
185
|
cross_attentions += (cross_attn_weights,)
|
182
186
|
|
@@ -205,6 +209,7 @@ class WhisperDecoderLayer(nn.Module):
|
|
205
209
|
self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
206
210
|
cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
207
211
|
cache_position: Optional[torch.Tensor] = None,
|
212
|
+
block_tables: Optional[torch.Tensor] = None,
|
208
213
|
) -> torch.Tensor:
|
209
214
|
# Self Attention Block
|
210
215
|
residual = hidden_states
|
@@ -214,6 +219,7 @@ class WhisperDecoderLayer(nn.Module):
|
|
214
219
|
past_key_value=self_past_key_value,
|
215
220
|
attention_mask=attention_mask,
|
216
221
|
cache_position=cache_position,
|
222
|
+
block_tables=block_tables,
|
217
223
|
)
|
218
224
|
hidden_states = residual + hidden_states
|
219
225
|
|
@@ -263,6 +269,7 @@ class WhisperSelfAttention(WhisperAttention):
|
|
263
269
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
264
270
|
attention_mask: Optional[torch.Tensor] = None,
|
265
271
|
cache_position: Optional[torch.Tensor] = None,
|
272
|
+
block_tables: Optional[torch.Tensor] = None,
|
266
273
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
267
274
|
bsz, tgt_len, _ = hidden_states.size()
|
268
275
|
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
@@ -270,8 +277,9 @@ class WhisperSelfAttention(WhisperAttention):
|
|
270
277
|
|
271
278
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
272
279
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
280
|
+
block_size = past_key_value[0].shape[-2]
|
273
281
|
|
274
|
-
attn_output = torch.ops.rbln_custom_ops.
|
282
|
+
attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
|
275
283
|
query_states,
|
276
284
|
key_states,
|
277
285
|
value_states,
|
@@ -280,6 +288,8 @@ class WhisperSelfAttention(WhisperAttention):
|
|
280
288
|
past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
|
281
289
|
cache_position.expand(bsz, 1),
|
282
290
|
torch.tensor(1.0, dtype=torch.float32), # scale
|
291
|
+
block_tables,
|
292
|
+
block_size,
|
283
293
|
)
|
284
294
|
|
285
295
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: optimum-rbln
|
3
|
-
Version: 0.7.
|
3
|
+
Version: 0.7.4a0
|
4
4
|
Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
|
5
5
|
Project-URL: Homepage, https://rebellions.ai
|
6
6
|
Project-URL: Documentation, https://docs.rbln.ai
|
@@ -1,5 +1,5 @@
|
|
1
1
|
optimum/rbln/__init__.py,sha256=ZDzXcl-oAcYJhKjJMpotjbTih9awo7HzUb6T3MUEP6Q,6894
|
2
|
-
optimum/rbln/__version__.py,sha256=
|
2
|
+
optimum/rbln/__version__.py,sha256=xyj1Oj5eR1yz0oBU9FRdubMKrBiNrPrrW8h8ohd1iG8,513
|
3
3
|
optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
|
4
4
|
optimum/rbln/modeling_base.py,sha256=dNCL-BhrWCpuOVkZaj8-MW567Tf4lLo3p3Z3ldjWJfU,21779
|
5
5
|
optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
|
@@ -41,8 +41,8 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py,sha256=9iIMZYvp
|
|
41
41
|
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py,sha256=OvB5bxX6HUiqJeIc3uukuEmUXYEx1pTqGNOtdG2l1m8,902
|
42
42
|
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
|
43
43
|
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
|
44
|
-
optimum/rbln/ops/__init__.py,sha256=
|
45
|
-
optimum/rbln/ops/attn.py,sha256=
|
44
|
+
optimum/rbln/ops/__init__.py,sha256=Wv2cJhEw8mqc6-To24bHzf4qQL8gM0Zh_2Ck77LB65g,947
|
45
|
+
optimum/rbln/ops/attn.py,sha256=OSgPoEgCwvR7HdjbnaVkFVMBcJ5RpRWcE6OCg2lVyGk,10634
|
46
46
|
optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
|
47
47
|
optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
|
48
48
|
optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
|
@@ -55,7 +55,7 @@ optimum/rbln/transformers/models/auto/auto_factory.py,sha256=IK9jFrJ3EEzYQa9_aKp
|
|
55
55
|
optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=Un9qoqdy3dO8JBza_bTJF_6_fRVNM9QisihSgTRFI-o,3933
|
56
56
|
optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd7M-1nop9g8hA1UZMHhyw,674
|
57
57
|
optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=Oo-Cdne7igKEex8wwP-gztKJHgs5GLHQjK1oc3IZIDE,5801
|
58
|
-
optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=
|
58
|
+
optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=CUF5PE9TxJxtO1VpuGgeKrL_u6PdsKxstlZDthYSXgU,5829
|
59
59
|
optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
|
60
60
|
optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaMCa1EfXttNJkqCJUIZo3CeZ9YY,4674
|
61
61
|
optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
|
@@ -92,17 +92,17 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
|
|
92
92
|
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
|
93
93
|
optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
|
94
94
|
optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
|
95
|
-
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=
|
95
|
+
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=QelhuCWEHPL2Ut7fm0gLnzTVveBAaKSNpoa9X1AmwTI,17709
|
96
96
|
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
|
97
97
|
optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
|
98
|
-
optimum/rbln/transformers/models/t5/modeling_t5.py,sha256
|
99
|
-
optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=
|
98
|
+
optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=-fG-h0wwsfjZ3par0QHbXKA7hbvw_lPJOIf8iXQDOfM,8082
|
99
|
+
optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=Ups6drBbYe4wEAiBLcBIyO9wqrIQbvOPFR_ybbAgR8c,9722
|
100
100
|
optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
|
101
101
|
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
|
102
102
|
optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
|
103
103
|
optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
|
104
|
-
optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=
|
105
|
-
optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=
|
104
|
+
optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=U9zK49DcSdXuoK_UOsVPsyKe6EJ5CQR8QZhpgi23EUU,16275
|
105
|
+
optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=ArQPOgiRVu-XddEN5FXVl1OlCoGF6uY7jGoWTj3Nfe4,13005
|
106
106
|
optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
|
107
107
|
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
|
108
108
|
optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -116,7 +116,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
|
|
116
116
|
optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
|
117
117
|
optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
|
118
118
|
optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
|
119
|
-
optimum_rbln-0.7.
|
120
|
-
optimum_rbln-0.7.
|
121
|
-
optimum_rbln-0.7.
|
122
|
-
optimum_rbln-0.7.
|
119
|
+
optimum_rbln-0.7.4a0.dist-info/METADATA,sha256=tXU0EmgjFJug_Cvmw8S9NeEZ2z9XpgamFwgMQTTCa1U,5300
|
120
|
+
optimum_rbln-0.7.4a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
121
|
+
optimum_rbln-0.7.4a0.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
122
|
+
optimum_rbln-0.7.4a0.dist-info/RECORD,,
|
File without changes
|
File without changes
|