optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a2__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
@@ -12,13 +12,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  
- from typing import Tuple
+ from typing import Optional, Tuple
  
  import torch
  from torch import nn
  from transformers.utils import logging
  
- from ....ops import register_rbln_custom_cache_update, register_rbln_custom_masked_attention
+ from ....ops import (
+     register_rbln_custom_cache_update,
+     register_rbln_custom_paged_attention,
+     register_rbln_custom_paged_causal_attention,
+ )
  
  
  logger = logging.get_logger(__name__)
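
The switch from a single masked-attention registration to paged and paged-causal variants implies a block-paged KV cache addressed through block tables. A minimal sketch of that addressing scheme in plain PyTorch, for orientation only — the real kernels are compiled RBLN custom ops, and every name below is an assumption, not the package's API:

    import torch

    def paged_cache_write(kv_cache, new_kv, block_tables, cache_position, block_size):
        # kv_cache:       [num_blocks, block_size, n_head, d_kv] physical block pool
        # new_kv:         [batch, 1, n_head, d_kv] current-step key (or value) states
        # block_tables:   [batch, max_blocks] physical block id per logical block
        # cache_position: [batch, 1] decoding step of each batch row
        for b in range(new_kv.shape[0]):
            pos = int(cache_position[b, 0])
            phys = int(block_tables[b, pos // block_size])  # logical -> physical block
            kv_cache[phys, pos % block_size] = new_kv[b, 0]  # slot within the block
        return kv_cache

    # Toy usage: 2 rows, a pool of 4 blocks of size 8, 2 heads, head dim 4.
    cache = torch.zeros(4, 8, 2, 4)
    new_kv = torch.randn(2, 1, 2, 4)
    tables = torch.tensor([[0, 1], [2, 3]])
    paged_cache_write(cache, new_kv, tables, torch.tensor([[3], [9]]), block_size=8)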
@@ -87,7 +91,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
          input_ids: torch.Tensor,
          attention_mask: torch.Tensor,
          cross_key_values: torch.Tensor,
-         batch_position: torch.Tensor,
+         b_idx: torch.Tensor,
      ) -> Tuple[torch.Tensor]:
          # 1. get encoder last_hidden_states
          encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -111,7 +115,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
          # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
          batch_axis = torch.tensor(1, dtype=torch.int16)
          cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
-             cross_key_values, cross_kv, batch_position, batch_axis
+             cross_key_values, cross_kv, b_idx[0], batch_axis
          )
  
          return cross_key_values
@@ -131,9 +135,10 @@ class Seq2SeqDecoderWrapper(nn.Module):
          **kwargs: Additional arguments for decoder configuration.
      """
  
-     def __init__(self, model: nn.Module, **kwargs):
+     def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
          super().__init__()
          self.config = model.config
+         self.use_attention_mask = use_attention_mask
          self.__post_init__(model, **kwargs)
  
      def __post_init__(self, model: nn.Module, **kwargs):
@@ -143,7 +148,11 @@ class Seq2SeqDecoderWrapper(nn.Module):
          It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
          by subclasses to modify or add custom attributes as necessary.
          """
-         register_rbln_custom_masked_attention()
+         if self.use_attention_mask:
+             register_rbln_custom_paged_attention()
+         else:
+             register_rbln_custom_paged_causal_attention()
+ 
          self.num_layers = self.config.decoder_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
  
@@ -160,13 +169,23 @@ class Seq2SeqDecoderWrapper(nn.Module):
  
      def forward(
          self,
-         input_ids: torch.Tensor,
-         attention_mask: torch.Tensor,
-         encoder_attention_mask: torch.Tensor,
-         cache_position: torch.Tensor,
-         cross_kv_cache: torch.Tensor,
-         *self_kv_cache: torch.Tensor,
+         *args,
      ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         if self.use_attention_mask:
+             (
+                 input_ids,
+                 attention_mask,
+                 encoder_attention_mask,
+                 cache_position,
+                 block_tables,
+                 cross_kv_cache,
+                 *self_kv_cache,
+             ) = args
+ 
+         else:
+             attention_mask = None
+             (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+ 
          self_past_key_values = ()
          cross_past_key_values = ()
          for i in range(0, self.num_layers * 2, 2):
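
Because the compiled decoder graph consumes purely positional inputs, the wrapper now takes a flat *args and unpacks it according to use_attention_mask. A caller therefore has to flatten its tensors in the matching order; a sketch of that flattening under the same assumption (argument names mirror the unpacking above, the helper itself is hypothetical):

    # Self-KV caches are interleaved per layer: (key_0, value_0, key_1, value_1, ...).
    def flatten_decoder_inputs(
        input_ids, cache_position, block_tables, cross_kv_cache, self_kv_cache,
        attention_mask=None, encoder_attention_mask=None,
    ):
        if attention_mask is not None:  # masked layout (paged attention)
            return (input_ids, attention_mask, encoder_attention_mask,
                    cache_position, block_tables, cross_kv_cache, *self_kv_cache)
        # maskless layout (paged causal attention): no attention_mask slot at all
        return (input_ids, encoder_attention_mask, cache_position,
                block_tables, cross_kv_cache, *self_kv_cache)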
@@ -181,6 +200,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
              cache_position=cache_position,
+             block_tables=block_tables,
          )
  
          outputs = (lm_logits,) + self_present_key_values
@@ -228,6 +248,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
          self_past_key_values,
          cross_past_key_values,
          cache_position,
+         block_tables: Optional[torch.Tensor] = None,
      ):
          hidden_states, self_present_key_values = self.decoder(
              input_ids=input_ids,
@@ -236,6 +257,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
              cache_position=cache_position,
+             block_tables=block_tables,
          )
  
          if self.has_rescaling and self.config.tie_word_embeddings:
@@ -292,6 +314,7 @@ class Seq2SeqDecoder(torch.nn.Module):
          self_past_key_values: torch.Tensor,
          cross_past_key_values: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
      ):
          # embedding
          hidden_states = self.get_embedding()(input_ids)
@@ -314,6 +337,7 @@ class Seq2SeqDecoder(torch.nn.Module):
                  self_past_key_value=self_past_key_value,
                  cross_past_key_value=cross_past_key_value,
                  cache_position=cache_position,
+                 block_tables=block_tables,
              )
              self_present_key_values += self_present_key_value
  
@@ -373,6 +397,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
          self_past_key_value: Tuple[torch.Tensor],
          cross_past_key_value: Tuple[torch.Tensor],
          cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
          dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])
  
@@ -384,6 +409,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
              past_key_value=self_past_key_value,
              attention_mask=attention_mask,
              cache_position=cache_position,
+             block_tables=block_tables,
          )
          hidden_states = residual + hidden_states
          hidden_states = self.post_self_attn_layer_norm(hidden_states)
@@ -407,10 +433,10 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
  
  
  class Seq2SeqSelfAttention(nn.Module):
-     def __init__(self, attn):
+     def __init__(self, attn, **kwargs):
          super().__init__()
          self._original_mod = attn
-         self.__post_init__()
+         self.__post_init__(**kwargs)
  
      def __post_init__(self, **kwargs):
          """
@@ -442,6 +468,7 @@ class Seq2SeqSelfAttention(nn.Module):
          past_key_value: Tuple[torch.Tensor],
          attention_mask: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
          bsz, tgt_len, _ = hidden_states.size()
  
@@ -450,18 +477,22 @@ class Seq2SeqSelfAttention(nn.Module):
          key_states = self._shape(key_states, -1, bsz)
          value_states = self._shape(value_states, -1, bsz)
  
-         attn_output, key_states, value_states = self.attn_decode(
+         block_size = past_key_value[0].shape[-2]
+         args = [
              query_states,
              key_states,
              value_states,
-             attention_mask.unsqueeze(
-                 2
-             ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
              past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
              past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
              cache_position,
              torch.tensor(1.0, dtype=torch.float32),  # scale
-         )
+             block_tables,
+             block_size,
+         ]
+         if attention_mask is not None:
+             args.insert(3, attention_mask.unsqueeze(2))
+ 
+         attn_output, key_states, value_states = self.attn_decode(*args)
  
          attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
          attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
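
The list-building above keeps one call site for both kernel signatures: the base argument list is maskless, and inserting the unsqueezed mask at index 3 (right after value_states, before the key cache) reproduces the masked layout. A small demonstration of that index arithmetic, with placeholder strings standing in for tensors:

    args = ["q", "k", "v", "k_cache", "v_cache", "pos", "scale", "tables", "size"]
    args.insert(3, "mask")  # masked variant: the mask lands between v and k_cache
    assert args[:4] == ["q", "k", "v", "mask"]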
@@ -13,8 +13,9 @@
  # limitations under the License.
  
  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
  
+ import rebel
  import torch
  from transformers import (
      AutoModelForTextEncoding,
@@ -22,7 +23,7 @@ from transformers import (
      T5EncoderModel,
      T5ForConditionalGeneration,
  )
- from transformers.modeling_outputs import BaseModelOutput
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
  
  from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
  from ....modeling import RBLNModel
@@ -57,6 +58,63 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          )
  
  
+ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+ 
+     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+         _ = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+ 
+ 
+ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+ 
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         batch_size: int,
+         dec_max_seq_len: int,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.batch_size = batch_size
+         self.dec_max_seq_len = dec_max_seq_len
+ 
+     def forward(
+         self,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         batch_size = decoder_input_ids.shape[0]
+         if batch_size != self.batch_size:
+             raise RuntimeError(
+                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+             )
+ 
+         if batch_size != cache_position.shape[0]:
+             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+ 
+         for b_idx in range(self.batch_size):
+             decoding_step = cache_position[b_idx].item()
+             if not (0 <= decoding_step < self.dec_max_seq_len):
+                 raise ValueError(
+                     f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                 )
+             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+ 
+         lm_logits = super().forward(
+             decoder_input_ids,
+             decoder_attention_mask,
+             attention_mask,
+             cache_position,
+         )
+ 
+         return Seq2SeqLMOutput(logits=lm_logits)
+ 
+ 
  class T5EncoderWrapper(torch.nn.Module):
      def __init__(self, model: "T5EncoderModel") -> None:
          super().__init__()
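
RBLNRuntimeDecoder validates the compiled batch size and then rebuilds the causal decoder mask from cache_position on every step. The loop's effect, isolated into a runnable snippet (shapes are illustrative):

    import torch

    batch_size, dec_max_seq_len = 2, 8
    cache_position = torch.tensor([[3], [5]])  # per-row decoding step
    decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)
    for b_idx in range(batch_size):
        step = cache_position[b_idx].item()
        decoder_attention_mask[b_idx, : step + 1] = 1  # unmask all past positions
    # row 0 -> 1 1 1 1 0 0 0 0, row 1 -> 1 1 1 1 1 1 0 0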
@@ -189,6 +247,21 @@ class RBLNT5EncoderModel(RBLNModel):
  
  
  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     def __post_init__(self, **kwargs):
+         batch_size = self.rbln_config.model_cfg["batch_size"]
+         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+ 
+         self.encoder = RBLNRuntimeEncoder(
+             runtime=self.model[0],
+             main_input_name="input_ids",
+         )
+         self.decoder = RBLNRuntimeDecoder(
+             runtime=self.model[1],
+             main_input_name="input_ids",
+             batch_size=batch_size,
+             dec_max_seq_len=dec_max_seq_len,
+         )
+ 
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
          enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
@@ -206,3 +279,139 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
              return redirect(val)
  
          return val
+ 
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model_config: "PretrainedConfig",
+         rbln_kwargs: Dict[str, Any] = {},
+     ) -> RBLNConfig:
+         rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+ 
+         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+         d_kv = (
+             model_config.d_kv
+             if hasattr(model_config, "d_kv")
+             else model_config.d_model // model_config.encoder_attention_heads
+         )
+ 
+         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+             model_config, "max_position_embeddings", None
+         )
+ 
+         rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
+         if rbln_pad_token_id is None:
+             rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
+             if rbln_pad_token_id is None:
+                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
+                 if rbln_pad_token_id is None:
+                     rbln_pad_token_id = -1
+ 
+         if rbln_enc_max_seq_len is None:
+             rbln_enc_max_seq_len = max_position_embeddings
+             if rbln_enc_max_seq_len is None:
+                 for tokenizer in preprocessors:
+                     if hasattr(tokenizer, "model_max_length"):
+                         rbln_enc_max_seq_len = tokenizer.model_max_length
+                         break
+                 if rbln_enc_max_seq_len is None:
+                     raise ValueError("`rbln_enc_max_seq_len` should be specified!")
+         if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
+             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+ 
+         if rbln_dec_max_seq_len is None:
+             rbln_dec_max_seq_len = max_position_embeddings
+             if rbln_dec_max_seq_len is None:
+                 for tokenizer in preprocessors:
+                     if hasattr(tokenizer, "model_max_length"):
+                         rbln_dec_max_seq_len = tokenizer.model_max_length
+                         break
+                 if rbln_dec_max_seq_len is None:
+                     raise ValueError("`rbln_dec_max_seq_len` should be specified!")
+ 
+         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
+             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+ 
+         # model input info
+         enc_input_info = [
+             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+             (
+                 "cross_key_value_states",
+                 [
+                     n_layer * 2,
+                     rbln_batch_size,
+                     n_head,
+                     rbln_enc_max_seq_len,
+                     d_kv,
+                 ],
+                 "float32",
+             ),
+             ("block_tables", [1], "int16"),
+         ]
+ 
+         dec_input_info = [
+             ("input_ids", [rbln_batch_size, 1], "int64"),
+             ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+             (
+                 "cache_position",
+                 [rbln_batch_size, 1],
+                 "int32",
+             ),
+         ]
+         dec_input_info.extend(
+             [
+                 (
+                     "cross_key_value_states",
+                     [
+                         n_layer * 2,
+                         rbln_batch_size,
+                         n_head,
+                         rbln_enc_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+         dec_input_info.extend(
+             [
+                 (
+                     f"self_key_value_states_{i}",
+                     [
+                         rbln_batch_size,
+                         n_head,
+                         rbln_dec_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )
+ 
+         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+ 
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[enc_compile_config, dec_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+ 
+         rbln_config.model_cfg.update(
+             {
+                 "enc_max_seq_len": rbln_enc_max_seq_len,
+                 "dec_max_seq_len": rbln_dec_max_seq_len,
+                 "batch_size": rbln_batch_size,
+                 "pad_token_id": rbln_pad_token_id,
+             }
+         )
+ 
+         return rbln_config
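
The rbln_kwargs consumed by _get_rbln_config correspond to the rbln_-prefixed options accepted at export time. A hedged usage sketch (the export flow follows optimum-rbln's documented from_pretrained convention; compilation itself requires the RBLN SDK and device):

    from optimum.rbln import RBLNT5ForConditionalGeneration

    # Compile a T5 checkpoint; these kwargs populate enc_max_seq_len,
    # dec_max_seq_len and batch_size in the RBLNConfig built above.
    model = RBLNT5ForConditionalGeneration.from_pretrained(
        "t5-small",
        export=True,
        rbln_batch_size=1,
        rbln_enc_max_seq_len=512,
        rbln_dec_max_seq_len=512,
    )
    model.save_pretrained("t5-small-rbln")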
@@ -18,7 +18,7 @@ import torch
  from torch import nn
  from transformers.utils import logging
  
- from ....ops import register_rbln_custom_attention_add_softmax
+ from ....ops import register_rbln_custom_add_softmax_attention
  from ..seq2seq.seq2seq_architecture import (
      Seq2SeqDecoder,
      Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
  
  class T5DecoderWrapper(Seq2SeqDecoderWrapper):
      def __post_init__(self, model, dec_max_seq_len: int = None):
-         register_rbln_custom_attention_add_softmax()
+         register_rbln_custom_add_softmax_attention()
          self.num_layers = self.config.num_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
  
@@ -71,6 +71,36 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
  
          return new_model
  
+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         encoder_attention_mask,
+         cache_position,
+         cross_kv_cache,
+         *self_kv_cache,
+     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         self_past_key_values = ()
+         cross_past_key_values = ()
+ 
+         for i in range(0, self.num_layers * 2, 2):
+             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+ 
+         # decode
+         lm_logits, self_present_key_values = self.conditional_generation(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+         )
+ 
+         outputs = (lm_logits,) + self_present_key_values
+ 
+         return outputs
+ 
  
  class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
      has_rescaling = True
@@ -134,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          self.out_proj = self._original_mod.o
          self.num_heads = self._original_mod.n_heads
          self.head_dim = self._original_mod.key_value_proj_dim
-         self.attn_decode = torch.ops.rbln_custom_ops.attn_decode_add_softmax
+         self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
  
      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          query_states = self.q_proj(hidden_states)
@@ -142,6 +172,42 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          value_states = self.v_proj(hidden_states)
          return query_states, key_states, value_states
  
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Tuple[torch.Tensor],
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         bsz, tgt_len, _ = hidden_states.size()
+ 
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+         query_states = self._shape(query_states, tgt_len, bsz)
+         key_states = self._shape(key_states, -1, bsz)
+         value_states = self._shape(value_states, -1, bsz)
+ 
+         attn_output, key_states, value_states = self.attn_decode(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask.unsqueeze(
+                 2
+             ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
+             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+             cache_position,
+             torch.tensor(1.0, dtype=torch.float32),  # scale
+         )
+ 
+         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
+         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
+ 
+         attn_output = self.out_proj(attn_output)
+         present_key_value = (key_states, value_states)
+ 
+         return attn_output, present_key_value
+ 
  
  class T5CrossAttention(nn.Module):
      def __init__(self, attn):
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.7.3a1
+ Version: 0.7.3a2
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,5 +1,5 @@
  optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
- optimum/rbln/__version__.py,sha256=3XXLT-7KoOXBM5ecjGQ9vxdHcJ06x38tTkK1veoUkmQ,513
+ optimum/rbln/__version__.py,sha256=bShBukYvw7AqWtLsut0yClygDEGsFRmxrXypqIeEXcQ,513
  optimum/rbln/modeling.py,sha256=3XE0IrCYbkjw9_Q9BFzZ_ri_Kyxw1g6iwfdohZB46-s,8289
  optimum/rbln/modeling_base.py,sha256=ELSPbjx7awBRM2SckkD-5gI3TIa01mfzz7gDRC1Pljk,21778
  optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
@@ -39,9 +39,9 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py,sha256=9iIMZYvp
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py,sha256=OvB5bxX6HUiqJeIc3uukuEmUXYEx1pTqGNOtdG2l1m8,902
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
- optimum/rbln/ops/__init__.py,sha256=-jcOGX3B8w5Znpr1z2eUsrK3TN-w9trrkSoqJRWgXdA,945
- optimum/rbln/ops/attn.py,sha256=WUsy4I25gm2j9Xdns9W2NNd3jCNcueqJuisDzp0jPaA,13899
- optimum/rbln/ops/flash_attn.py,sha256=aQRupKPvJsNFWKrHaeyXg-LemyUJWmCJaVrA__Mjabo,3869
+ optimum/rbln/ops/__init__.py,sha256=TxOmsN0u3PmyK4Sb89qbiC4rePOlkvUT7Lm6wVoTnY0,941
+ optimum/rbln/ops/attn.py,sha256=LbJAmFtNj05i6BURfKV3KybsPItFe8w-YdSe5SuWkEc,12365
+ optimum/rbln/ops/flash_attn.py,sha256=4shKNY13skPoYnbEsGrXDzgNwBIhHZEFrnUnWx1ESZU,4076
  optimum/rbln/ops/kv_cache_update.py,sha256=9W4WCO1Dtfy0u5i978JJRa7uLbqrfR2lHuoPynb07fw,3143
  optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
  optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
@@ -52,25 +52,25 @@ optimum/rbln/transformers/models/auto/__init__.py,sha256=GvGbb3ZpMv-h6euXeZ42jSi
  optimum/rbln/transformers/models/auto/auto_factory.py,sha256=IK9jFrJ3EEzYQa9_aKpcp2TO68M5YGkA-HcfBVpA2QU,7027
  optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=Un9qoqdy3dO8JBza_bTJF_6_fRVNM9QisihSgTRFI-o,3933
  optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd7M-1nop9g8hA1UZMHhyw,674
- optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=ZV-6Y3xABJsGAw2wi3vGYZNXbeVp-DlI2uUsdsa-8M8,5486
- optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=QZCTJA0K90YBzkCXxs3JR9Ol9lbmAn50RDeN2hcWtx8,5673
+ optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=Oo-Cdne7igKEex8wwP-gztKJHgs5GLHQjK1oc3IZIDE,5801
+ optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=iI3ubPOVvHmhLt0wEz_vkOfMyNTHVNjmnkLtbpOX760,5797
  optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
  optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaMCa1EfXttNJkqCJUIZo3CeZ9YY,4674
  optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
  optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
  optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=6X87ZVvz4wHoTATdaxxSLy8wBfsEkUwWQISlo_mXPKM,40822
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=BYENVqueqR121nPyh2LnV_sMdhl95GRgqHLnCcX2sz8,29067
+ optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=x8_xQ5aGXbadJyajpJQyfgxx4YPSj62VlmmGDMnC-1E,41819
+ optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=dyl8tDBjfe5VfU1XbKAoZS7g7F90JTYVmMuz0HTmCoE,35345
  optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
  optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
  optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
- optimum/rbln/transformers/models/exaone/exaone_architecture.py,sha256=aPit1EOe3s3g3IVhztU1wydiTjYGA_j02btV9dl8W_I,3119
+ optimum/rbln/transformers/models/exaone/exaone_architecture.py,sha256=ZM5vvz8KBipOiMVi8vqfvejkDSknW69xh4GrvJix-g0,3350
  optimum/rbln/transformers/models/exaone/modeling_exaone.py,sha256=WjyH8PmsMljSea7kJn_Cq1FJ96OXwXAoU7hv2Q8zUnI,1747
  optimum/rbln/transformers/models/gemma/__init__.py,sha256=7qUrekuBwCI9a6_Fq6j--FzCirRtUDz3ApY17mQS5Y4,648
- optimum/rbln/transformers/models/gemma/gemma_architecture.py,sha256=_GPIcSY5Q3PPuTehEseEf43mMBkW9Gl6pJlnHnjmkkM,2055
+ optimum/rbln/transformers/models/gemma/gemma_architecture.py,sha256=bmCx405FVcffhgrQ53qMMZDbSlPxWOjucMHbvq19Gnw,2286
  optimum/rbln/transformers/models/gemma/modeling_gemma.py,sha256=-U3w3cEOv3ps1S8aL7uOq6Kq2siCPZz7Z8MXhDQgQqo,1530
  optimum/rbln/transformers/models/gpt2/__init__.py,sha256=UwwPPYVTB9ywDWy314L2bNL0i7wfkQFA71qjgXicEPg,646
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py,sha256=eN_UFcaaxWrtXvAGYck7J19Im2GZub2pwOejF3VWR6I,2934
+ optimum/rbln/transformers/models/gpt2/gpt2_architecture.py,sha256=1IxqHmB-GlH2Dv2Yk4z0rMxL9CpxMGHhSu_x8_4cxvs,3008
  optimum/rbln/transformers/models/gpt2/modeling_gpt2.py,sha256=qBDanUk_O-HtOIVCA4IE3FYyCsnL9xIDK00vft-0caw,1490
  optimum/rbln/transformers/models/llama/__init__.py,sha256=jo_j_eIrHYGNEhR5lb6g3r5sO0ewe3fm2TWO8mLrT58,648
  optimum/rbln/transformers/models/llama/llama_architecture.py,sha256=S7MCPfyjG5eUqgaS-QNBB0ApUD6wnb5fR0RHq7k7-pA,728
@@ -78,23 +78,23 @@ optimum/rbln/transformers/models/llama/modeling_llama.py,sha256=Z3iony7icoFhRQ11
  optimum/rbln/transformers/models/llava_next/__init__.py,sha256=VLieyWm-UgvuNxw9B38wrL1Jsa09NBDX_ebABmdpTbs,670
  optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=w_plsUOzxnhkQBhQeUqW9aJqGCvCvLtsx0XNKYjOprU,26203
  optimum/rbln/transformers/models/midm/__init__.py,sha256=UJSaErsF-z6dZERIS143WTaygffZyzEGqoQ2ZPDiM-c,855
- optimum/rbln/transformers/models/midm/midm_architecture.py,sha256=au6jHs7UQjthXDOrL7aqlOw7fkwM0-vkKkLGWeV1KKQ,5370
+ optimum/rbln/transformers/models/midm/midm_architecture.py,sha256=357iviqQkzI0s_lU_teH1sVOChNRDUABe3GA0HuhZZY,5444
  optimum/rbln/transformers/models/midm/modeling_midm.py,sha256=GG25BozEZriAL-OPFGpzOjyDtSFB-NfeiLJTDAqxe20,1734
  optimum/rbln/transformers/models/mistral/__init__.py,sha256=jpGdNtRLoV7WmuYpRGVXR27BTC8RIi_nhmvYlxuhqRc,652
  optimum/rbln/transformers/models/mistral/mistral_architecture.py,sha256=_aU8TE_tdvfo0K7QpgTlz_d0qwk4O82dl9268lPL16E,733
  optimum/rbln/transformers/models/mistral/modeling_mistral.py,sha256=7nrddoBIHf8S12LZWBUpotnvG3gND11vMQda9yYXJ-s,1560
  optimum/rbln/transformers/models/phi/__init__.py,sha256=mZLt1M7BbYEvSon5UlkniMUPa15SfjZFdw0kMSAF3VA,644
  optimum/rbln/transformers/models/phi/modeling_phi.py,sha256=j-6Pqd5rR2JE8I1pnKFlCi4nW5Dv3wZjoPWxohissoo,1516
- optimum/rbln/transformers/models/phi/phi_architecture.py,sha256=QQBf5tlJDYuEHy8wLRpQW9vhYV3R6kr5OLTt4ZXrwl8,4039
+ optimum/rbln/transformers/models/phi/phi_architecture.py,sha256=rBQjr6MOYBo1i5yLekMSR81TzYlHrHAA30kyKDdR7ww,4132
  optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxPECmAcTrbKhSIefq3Lass0,648
  optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
  optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
  optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=HG_-8ufRWIls67imU1547V0bk9FUWC0haOBL7eyRV6k,16365
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=jPiRo2woijKd8VOHKb0qhBmy0vw4_WHQQh1JotlTx1w,18390
+ optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
+ optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=QXIGWSu9PsKWE3WhkgmBj3VeszqXIo2MPOwcrb54Tbs,19348
  optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
- optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=9AHRoGsr4eD_dIm1JA6ojafqIxd4w5Upzw3HmKOADkk,8049
- optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=kkjErS42mW2jv5O_xL7BaKobvvqy7BGmYOowKyHakvI,7189
+ optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
+ optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=oCdmF4eCTayAVjx3c-SVpmhrjnWE92jh79dMIYCwotY,9690
  optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
  optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
  optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
@@ -114,7 +114,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
  optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
  optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
- optimum_rbln-0.7.3a1.dist-info/METADATA,sha256=59cCm0xXF4GfQ4oMuDLjOYHjkhUMqnZayBGMBstYd0Q,5300
- optimum_rbln-0.7.3a1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- optimum_rbln-0.7.3a1.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- optimum_rbln-0.7.3a1.dist-info/RECORD,,
+ optimum_rbln-0.7.3a2.dist-info/METADATA,sha256=C-IWumO-veJFZPHpF8wcOTOE0TCDzKU1Xk_ylaqrvPM,5300
+ optimum_rbln-0.7.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ optimum_rbln-0.7.3a2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ optimum_rbln-0.7.3a2.dist-info/RECORD,,