optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.7.3.post2'
- __version_tuple__ = version_tuple = (0, 7, 3)
+ __version__ = version = '0.7.4a0'
+ __version_tuple__ = version_tuple = (0, 7, 4)
optimum/rbln/ops/__init__.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.

  from .attn import (
-     register_rbln_custom_add_softmax_attention,
+     register_rbln_custom_paged_add_softmax_attention,
      register_rbln_custom_paged_attention,
      register_rbln_custom_paged_causal_attention,
  )
optimum/rbln/ops/attn.py CHANGED
@@ -182,14 +182,14 @@ def register_rbln_custom_paged_causal_attention():


  @lru_cache
- def register_rbln_custom_add_softmax_attention():
+ def register_rbln_custom_paged_add_softmax_attention():
      torch.library.define(
-         "rbln_custom_ops::add_softmax_attn_decode",
-         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
+         "rbln_custom_ops::paged_add_softmax_attn_decode",
+         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
      )

-     @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
-     def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+     @torch.library.impl("rbln_custom_ops::paged_add_softmax_attn_decode", "cpu")
+     def paged_add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
          """Defines the computation pattern for fused attention with KV cache updates.

          IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -210,12 +210,14 @@ def register_rbln_custom_add_softmax_attention():
          - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
          - seq: [1] - Current sequence position
          - scale: [] - Attention scale factor
+         - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+         - block_size: [] - Number of tokens per block

          Returns:
              Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
          """
          return q

-     @register_fake("rbln_custom_ops::add_softmax_attn_decode")
-     def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+     @register_fake("rbln_custom_ops::paged_add_softmax_attn_decode")
+     def paged_add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition, block_table, block_size):
          return q
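The two hunks above rename the pattern op and widen its schema from eight tensors to nine tensors plus an `int` (`block_table` and `block_size`). As a sanity check of the new signature, here is a minimal sketch (not from the package) that registers the op and invokes its CPU pattern implementation; the concrete batch, head, and length values and the mask layout are illustrative assumptions, and on CPU the op simply returns `q`, since the real fused kernel is generated by the RBLN compiler:

```python
# Minimal sketch, assuming the 0.7.4a0 wheel is installed; shape values are
# illustrative, only the docstring shapes above are authoritative.
import torch
from optimum.rbln.ops import register_rbln_custom_paged_add_softmax_attention

register_rbln_custom_paged_add_softmax_attention()

batch, n_heads, head_dim, max_seq_len, block_size = 1, 8, 64, 128, 128
q = torch.randn(batch, n_heads, 1, 1, head_dim)                 # decode-step query
k = torch.randn(batch, n_heads, 1, 1, head_dim)
v = torch.randn(batch, n_heads, 1, 1, head_dim)
mask = torch.zeros(batch, 1, 1, 1, max_seq_len)                 # assumed mask layout
kcache = torch.zeros(batch, n_heads, 1, max_seq_len, head_dim)  # per docstring
vcache = torch.zeros(batch, n_heads, 1, max_seq_len, head_dim)
seq = torch.zeros(1, dtype=torch.int32)                         # current position
scale = torch.tensor(1.0)                                       # attention scale
block_table = torch.arange(batch, dtype=torch.int16).view(batch, 1)

# The CPU impl is only a pattern for the compiler; it returns q unchanged.
out = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size
)
assert out.shape == q.shape
```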
optimum/rbln/transformers/models/bart/modeling_bart.py CHANGED
@@ -108,6 +108,8 @@ class RBLNBartModel(RBLNModel):


  class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     support_causal_attn = True
+
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
          enc_max_seq_len = (
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py CHANGED
@@ -94,7 +94,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
              decoder_attention_mask if self.use_attention_mask else None,
              attention_mask,
              cache_position,
-             block_tables,
+             block_tables=block_tables,
          )

          return Seq2SeqLMOutput(logits=lm_logits)
@@ -115,6 +115,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

      main_input_name = "input_ids"
      auto_model_class = AutoModelForSeq2SeqLM
+     support_causal_attn = None

      def __post_init__(self, **kwargs):
          batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -186,13 +187,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
          rbln_batch_size = rbln_kwargs.get("batch_size", None)
          rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-         rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)

-         if rbln_use_attention_mask is None:
-             rbln_use_attention_mask = False
-             rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-             if rbln_npu == "RBLN-CA02":
-                 rbln_use_attention_mask = True
+         if cls.support_causal_attn:
+             rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+             if rbln_use_attention_mask is None:
+                 rbln_use_attention_mask = False
+                 rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+                 if rbln_npu == "RBLN-CA02":
+                     rbln_use_attention_mask = True
+         else:
+             rbln_use_attention_mask = True

          n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
          n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
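For reference, the branch introduced above reduces to the following standalone logic (a paraphrase, not package code): subclasses that opt into causal attention, like `RBLNBartForConditionalGeneration` (`support_causal_attn = True`), only fall back to an explicit attention mask when the caller asks for one or when targeting the RBLN-CA02 NPU, while subclasses without causal-attention support, like `RBLNT5ForConditionalGeneration` (`support_causal_attn = False`), always compile with a mask.

```python
# Standalone paraphrase of the branch above (not package code). `npu_name`
# stands in for `rbln_kwargs.get("npu", None) or rebel.get_npu_name()`.
def resolve_use_attention_mask(support_causal_attn, rbln_kwargs, npu_name):
    if support_causal_attn:
        use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
        if use_attention_mask is None:
            # Default to causal attention; RBLN-CA02 still requires the mask.
            use_attention_mask = npu_name == "RBLN-CA02"
        return use_attention_mask
    return True  # no causal-attention support (e.g. T5): always mask


assert resolve_use_attention_mask(True, {}, "RBLN-CA02") is True
assert resolve_use_attention_mask(True, {}, "other-npu") is False
assert resolve_use_attention_mask(False, {}, "other-npu") is True
```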
@@ -265,11 +269,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                  [rbln_batch_size, 1],
                  "int32",
              ),
-             (
-                 "block_tables",
-                 [rbln_batch_size, 1],
-                 "int16",
-             ),
+             ("block_tables", [rbln_batch_size, 1], "int16"),
          ]
          dec_input_info.extend(
              [
optimum/rbln/transformers/models/t5/modeling_t5.py CHANGED
@@ -13,9 +13,8 @@
  # limitations under the License.

  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

- import rebel
  import torch
  from transformers import (
      AutoModelForTextEncoding,
@@ -23,7 +22,7 @@ from transformers import (
      T5EncoderModel,
      T5ForConditionalGeneration,
  )
- from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+ from transformers.modeling_outputs import BaseModelOutput

  from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
  from ....modeling import RBLNModel
@@ -58,63 +57,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          )


- class RBLNRuntimeEncoder(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
-
-     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-         _ = super().forward(*args, **kwargs)
-         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
-
-
- class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
-
-     def __init__(
-         self,
-         runtime: rebel.Runtime,
-         batch_size: int,
-         dec_max_seq_len: int,
-         **kwargs: Any,
-     ) -> None:
-         super().__init__(runtime, **kwargs)
-         self.batch_size = batch_size
-         self.dec_max_seq_len = dec_max_seq_len
-
-     def forward(
-         self,
-         decoder_input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         decoder_attention_mask: Optional[torch.BoolTensor] = None,
-         cache_position: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor]:
-         batch_size = decoder_input_ids.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         for b_idx in range(self.batch_size):
-             decoding_step = cache_position[b_idx].item()
-             if not (0 <= decoding_step < self.dec_max_seq_len):
-                 raise ValueError(
-                     f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                 )
-             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
-
-         lm_logits = super().forward(
-             decoder_input_ids,
-             decoder_attention_mask,
-             attention_mask,
-             cache_position,
-         )
-
-         return Seq2SeqLMOutput(logits=lm_logits)
-
-
  class T5EncoderWrapper(torch.nn.Module):
      def __init__(self, model: "T5EncoderModel") -> None:
          super().__init__()
@@ -247,20 +189,7 @@ class RBLNT5EncoderModel(RBLNModel):


  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-     def __post_init__(self, **kwargs):
-         batch_size = self.rbln_config.model_cfg["batch_size"]
-         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-
-         self.encoder = RBLNRuntimeEncoder(
-             runtime=self.model[0],
-             main_input_name="input_ids",
-         )
-         self.decoder = RBLNRuntimeDecoder(
-             runtime=self.model[1],
-             main_input_name="input_ids",
-             batch_size=batch_size,
-             dec_max_seq_len=dec_max_seq_len,
-         )
+     support_causal_attn = False

      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
@@ -279,139 +208,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
              return redirect(val)

          return val
-
-     @classmethod
-     def _get_rbln_config(
-         cls,
-         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "PretrainedConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
-         d_kv = (
-             model_config.d_kv
-             if hasattr(model_config, "d_kv")
-             else model_config.d_model // model_config.encoder_attention_heads
-         )
-
-         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-             model_config, "max_position_embeddings", None
-         )
-
-         rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-             if rbln_pad_token_id is None:
-                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                 if rbln_pad_token_id is None:
-                     rbln_pad_token_id = -1
-
-         if rbln_enc_max_seq_len is None:
-             rbln_enc_max_seq_len = max_position_embeddings
-             if rbln_enc_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_enc_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_enc_max_seq_len is None:
-                     raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-         if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         if rbln_dec_max_seq_len is None:
-             rbln_dec_max_seq_len = max_position_embeddings
-             if rbln_dec_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_dec_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_dec_max_seq_len is None:
-                     raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         # model input info
-         enc_input_info = [
-             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-             (
-                 "cross_key_value_states",
-                 [
-                     n_layer * 2,
-                     rbln_batch_size,
-                     n_head,
-                     rbln_enc_max_seq_len,
-                     d_kv,
-                 ],
-                 "float32",
-             ),
-             ("block_tables", [1], "int16"),
-         ]
-
-         dec_input_info = [
-             ("input_ids", [rbln_batch_size, 1], "int64"),
-             ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
-             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
-             (
-                 "cache_position",
-                 [rbln_batch_size, 1],
-                 "int32",
-             ),
-         ]
-         dec_input_info.extend(
-             [
-                 (
-                     "cross_key_value_states",
-                     [
-                         n_layer * 2,
-                         rbln_batch_size,
-                         n_head,
-                         rbln_enc_max_seq_len,
-                         d_kv,
-                     ],
-                     "float32",
-                 )
-             ]
-         )
-         dec_input_info.extend(
-             [
-                 (
-                     f"self_key_value_states_{i}",
-                     [
-                         rbln_batch_size,
-                         n_head,
-                         rbln_dec_max_seq_len,
-                         d_kv,
-                     ],
-                     "float32",
-                 )
-                 for i in range(n_layer * 2)
-             ]
-         )
-
-         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
-         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
-
-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[enc_compile_config, dec_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "enc_max_seq_len": rbln_enc_max_seq_len,
-                 "dec_max_seq_len": rbln_dec_max_seq_len,
-                 "batch_size": rbln_batch_size,
-                 "pad_token_id": rbln_pad_token_id,
-             }
-         )
-
-         return rbln_config
optimum/rbln/transformers/models/t5/t5_architecture.py CHANGED
@@ -18,7 +18,7 @@ import torch
  from torch import nn
  from transformers.utils import logging

- from ....ops import register_rbln_custom_add_softmax_attention
+ from ....ops import register_rbln_custom_paged_add_softmax_attention
  from ..seq2seq.seq2seq_architecture import (
      Seq2SeqDecoder,
      Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):

  class T5DecoderWrapper(Seq2SeqDecoderWrapper):
      def __post_init__(self, model, dec_max_seq_len: int = None):
-         register_rbln_custom_add_softmax_attention()
+         register_rbln_custom_paged_add_softmax_attention()
          self.num_layers = self.config.num_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

@@ -77,6 +77,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
          attention_mask,
          encoder_attention_mask,
          cache_position,
+         block_tables,
          cross_kv_cache,
          *self_kv_cache,
      ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
@@ -95,6 +96,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
              cache_position=cache_position,
+             block_tables=block_tables,
          )

          return lm_logits
@@ -162,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          self.out_proj = self._original_mod.o
          self.num_heads = self._original_mod.n_heads
          self.head_dim = self._original_mod.key_value_proj_dim
-         self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
+         self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode

      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          query_states = self.q_proj(hidden_states)
@@ -176,6 +178,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          past_key_value: Tuple[torch.Tensor],
          attention_mask: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: torch.Tensor,
          **kwargs,
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
          bsz, tgt_len, _ = hidden_states.size()
@@ -185,6 +188,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          key_states = self._shape(key_states, -1, bsz)
          value_states = self._shape(value_states, -1, bsz)

+         block_size = past_key_value[0].shape[-2]
          attn_output = self.attn_decode(
              query_states,
              key_states,
@@ -196,6 +200,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
              past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
              cache_position,
              torch.tensor(1.0, dtype=torch.float32),  # scale
+             block_tables,
+             block_size,
          )

          attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
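A note on the `block_size = past_key_value[0].shape[-2]` line added above: the decoder's self-attention KV caches are declared in `dec_input_info` with shape `[batch_size, n_heads, dec_max_seq_len, head_dim]`, so dim `-2` is the full per-sequence cache length, and together with the trivial `[batch_size, 1]` block table each batch entry owns a single block spanning its entire cache. A small sketch, with assumed sizes:

```python
# Illustrative sizes only; the real shapes come from the compiled input info.
import torch

bsz, n_heads, dec_max_seq_len, head_dim = 2, 8, 256, 64
k_cache = torch.zeros(bsz, n_heads, dec_max_seq_len, head_dim)

block_size = k_cache.shape[-2]  # 256: one block covers the whole cache
block_tables = torch.arange(bsz, dtype=torch.int16).view(bsz, 1)
print(block_size, block_tables.tolist())  # 256 [[0], [1]]
```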
optimum/rbln/transformers/models/whisper/modeling_whisper.py CHANGED
@@ -61,6 +61,16 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]

+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         batch_size: int,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.batch_size = batch_size
+         self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
      def forward(
          self,
          decoder_input_ids: torch.Tensor = None,
@@ -76,6 +86,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              cache_position=cache_position,
+             block_tables=self.default_block_tables,
          )

          if isinstance(outputs, torch.Tensor):
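The `default_block_tables` buffer built in `__init__` above keeps the paged op equivalent to the previous per-batch cache layout: batch entry `i` is statically mapped to cache block `i`. A quick illustration (the batch size is an assumed value):

```python
import torch

batch_size = 4  # illustrative compiled batch size
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables.tolist())  # [[0], [1], [2], [3]]
```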
@@ -237,6 +248,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
              ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
              ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
              ("cache_position", [], "int32"),
+             ("block_tables", [rbln_batch_size, 1], "int16"),
          ]
          dec_input_info.extend(
              [
optimum/rbln/transformers/models/whisper/whisper_architecture.py CHANGED
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
  )
  from transformers.utils import logging

- from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
+ from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention


  logger = logging.get_logger(__name__)
@@ -34,7 +34,7 @@ logger = logging.get_logger(__name__)
  class WhisperWrapper:
      def __init__(self, model, rbln_token_timestamps):
          register_rbln_custom_cache_update()
-         register_rbln_custom_add_softmax_attention()
+         register_rbln_custom_paged_add_softmax_attention()
          self.encoder = WhisperEncoderWrapper(model)
          self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)

@@ -108,6 +108,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
          decoder_input_ids: torch.Tensor,
          decoder_attention_mask: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: torch.Tensor,
          cross_kv_cache: torch.Tensor,
          *self_kv_cache: torch.Tensor,
      ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
@@ -125,6 +126,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
              cache_position=cache_position,
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
+             block_tables=block_tables,
          )

          lm_logits = self.proj_out(sequence_output)
@@ -154,6 +156,7 @@ class WhisperDecoder(nn.Module):
          self_past_key_values: Optional[torch.Tensor] = None,
          cross_past_key_values: Optional[torch.Tensor] = None,
          cache_position: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
      ):
          input_shape = input_ids.size()
          input_ids = input_ids.view(-1, input_shape[-1])
@@ -177,6 +180,7 @@ class WhisperDecoder(nn.Module):
                  self_past_key_value=self_past_key_value,
                  cross_past_key_value=cross_past_key_value,
                  cache_position=cache_position,
+                 block_tables=block_tables,
              )
              cross_attentions += (cross_attn_weights,)

@@ -205,6 +209,7 @@ class WhisperDecoderLayer(nn.Module):
          self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
          cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
          cache_position: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          # Self Attention Block
          residual = hidden_states
@@ -214,6 +219,7 @@ class WhisperDecoderLayer(nn.Module):
              past_key_value=self_past_key_value,
              attention_mask=attention_mask,
              cache_position=cache_position,
+             block_tables=block_tables,
          )
          hidden_states = residual + hidden_states

@@ -263,6 +269,7 @@ class WhisperSelfAttention(WhisperAttention):
          past_key_value: Optional[Tuple[torch.Tensor]] = None,
          attention_mask: Optional[torch.Tensor] = None,
          cache_position: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
          bsz, tgt_len, _ = hidden_states.size()
          query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
@@ -270,8 +277,9 @@ class WhisperSelfAttention(WhisperAttention):

          key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
          value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+         block_size = past_key_value[0].shape[-2]

-         attn_output = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+         attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
              query_states,
              key_states,
              value_states,
@@ -280,6 +288,8 @@ class WhisperSelfAttention(WhisperAttention):
              past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
              cache_position.expand(bsz, 1),
              torch.tensor(1.0, dtype=torch.float32),  # scale
+             block_tables,
+             block_size,
          )

          attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
optimum_rbln-0.7.3.post2.dist-info/METADATA → optimum_rbln-0.7.4a0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.7.3.post2
+ Version: 0.7.4a0
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai
optimum_rbln-0.7.3.post2.dist-info/RECORD → optimum_rbln-0.7.4a0.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  optimum/rbln/__init__.py,sha256=ZDzXcl-oAcYJhKjJMpotjbTih9awo7HzUb6T3MUEP6Q,6894
- optimum/rbln/__version__.py,sha256=OJRzB6Y7xaNgH7EkerbumPEoY0Nlzs1_xYhBJvOXTzQ,517
+ optimum/rbln/__version__.py,sha256=xyj1Oj5eR1yz0oBU9FRdubMKrBiNrPrrW8h8ohd1iG8,513
  optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
  optimum/rbln/modeling_base.py,sha256=dNCL-BhrWCpuOVkZaj8-MW567Tf4lLo3p3Z3ldjWJfU,21779
  optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
@@ -41,8 +41,8 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py,sha256=9iIMZYvp
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py,sha256=OvB5bxX6HUiqJeIc3uukuEmUXYEx1pTqGNOtdG2l1m8,902
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
- optimum/rbln/ops/__init__.py,sha256=TxOmsN0u3PmyK4Sb89qbiC4rePOlkvUT7Lm6wVoTnY0,941
- optimum/rbln/ops/attn.py,sha256=3EqU63Oj4zI4rLbkRycorsscXeD-IpKzt9N1MhkMa5o,10374
+ optimum/rbln/ops/__init__.py,sha256=Wv2cJhEw8mqc6-To24bHzf4qQL8gM0Zh_2Ck77LB65g,947
+ optimum/rbln/ops/attn.py,sha256=OSgPoEgCwvR7HdjbnaVkFVMBcJ5RpRWcE6OCg2lVyGk,10634
  optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
  optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
  optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
@@ -55,7 +55,7 @@ optimum/rbln/transformers/models/auto/auto_factory.py,sha256=IK9jFrJ3EEzYQa9_aKp
  optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=Un9qoqdy3dO8JBza_bTJF_6_fRVNM9QisihSgTRFI-o,3933
  optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd7M-1nop9g8hA1UZMHhyw,674
  optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=Oo-Cdne7igKEex8wwP-gztKJHgs5GLHQjK1oc3IZIDE,5801
- optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=iI3ubPOVvHmhLt0wEz_vkOfMyNTHVNjmnkLtbpOX760,5797
+ optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=CUF5PE9TxJxtO1VpuGgeKrL_u6PdsKxstlZDthYSXgU,5829
  optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
  optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaMCa1EfXttNJkqCJUIZo3CeZ9YY,4674
  optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
@@ -92,17 +92,17 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
  optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
  optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
  optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
+ optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=QelhuCWEHPL2Ut7fm0gLnzTVveBAaKSNpoa9X1AmwTI,17709
  optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
  optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
- optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
- optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
+ optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=-fG-h0wwsfjZ3par0QHbXKA7hbvw_lPJOIf8iXQDOfM,8082
+ optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=Ups6drBbYe4wEAiBLcBIyO9wqrIQbvOPFR_ybbAgR8c,9722
  optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
  optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
  optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
  optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
- optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
- optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=Yn6yFpmw6IQbWlnpIMAdEUsNF6huXgaKzGMUZbhSLdo,12572
+ optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=U9zK49DcSdXuoK_UOsVPsyKe6EJ5CQR8QZhpgi23EUU,16275
+ optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=ArQPOgiRVu-XddEN5FXVl1OlCoGF6uY7jGoWTj3Nfe4,13005
  optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
  optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
  optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -116,7 +116,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
  optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
  optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
- optimum_rbln-0.7.3.post2.dist-info/METADATA,sha256=YgOp5SEpJ_VfYEohAoBhSQ20TaX1usvkRAzV7s7mS5I,5304
- optimum_rbln-0.7.3.post2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- optimum_rbln-0.7.3.post2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- optimum_rbln-0.7.3.post2.dist-info/RECORD,,
+ optimum_rbln-0.7.4a0.dist-info/METADATA,sha256=tXU0EmgjFJug_Cvmw8S9NeEZ2z9XpgamFwgMQTTCa1U,5300
+ optimum_rbln-0.7.4a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ optimum_rbln-0.7.4a0.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ optimum_rbln-0.7.4a0.dist-info/RECORD,,