optimum-rbln 0.7.2rc2__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (40)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +8 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +103 -117
  5. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
  6. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
  7. optimum/rbln/diffusers/pipelines/__init__.py +8 -0
  8. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
  9. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
  10. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
  11. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
  12. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
  13. optimum/rbln/modeling.py +4 -1
  14. optimum/rbln/modeling_base.py +16 -3
  15. optimum/rbln/ops/__init__.py +6 -2
  16. optimum/rbln/ops/attn.py +94 -85
  17. optimum/rbln/ops/flash_attn.py +46 -25
  18. optimum/rbln/ops/kv_cache_update.py +4 -4
  19. optimum/rbln/transformers/modeling_generic.py +3 -3
  20. optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
  21. optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
  22. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
  25. optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
  26. optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
  27. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
  28. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
  29. optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
  30. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
  31. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
  32. optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
  33. optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
  34. optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
  35. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
  36. optimum/rbln/utils/import_utils.py +7 -0
  37. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
  38. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
  39. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
  40. {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -19,7 +19,12 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

-from ....ops import register_rbln_custom_attention, register_rbln_custom_flash_attention
+from ....ops import (
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+    register_rbln_custom_paged_flash_attention,
+    register_rbln_custom_paged_flash_causal_attention,
+)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS

@@ -34,7 +39,7 @@ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768


 def validate_attention_method(
-    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+    rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_kvcache_block_size: int, rbln_max_seq_len: int
 ) -> Tuple[str, int]:
     if rbln_kvcache_partition_len is not None:
         if rbln_attn_impl == "eager":
@@ -93,7 +98,19 @@ def validate_attention_method(
                 "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
             )

-    return rbln_attn_impl, rbln_kvcache_partition_len
+    if rbln_kvcache_block_size is not None:
+        if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
+            raise ValueError(
+                f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
+                f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
+            )
+        elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
+            raise ValueError(
+                f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
+                f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
+            )
+
+    return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size


 class DecoderOnlyWrapper(nn.Module):
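In short, the new check ties `rbln_kvcache_block_size` to the other knobs: it must equal `rbln_kvcache_partition_len` under flash attention and `rbln_max_seq_len` under eager attention. A minimal illustration with hypothetical values (not library defaults):

    # Hypothetical values chosen only to satisfy the checks added above.
    max_seq_len = 8192

    # eager attention: block size must equal the full sequence length
    eager_kwargs = dict(rbln_attn_impl="eager", rbln_kvcache_block_size=max_seq_len)

    # flash attention: block size must equal the KV-cache partition length
    flash_kwargs = dict(
        rbln_attn_impl="flash_attn",
        rbln_kvcache_partition_len=4096,
        rbln_kvcache_block_size=4096,
    )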
@@ -102,7 +119,7 @@ class DecoderOnlyWrapper(nn.Module):
     This wrapper is designed to:
     1. Convert Huggingface decoder models for RBLN compilation with static shapes
     2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
-    3. Manage different attention implementations (standard and flash attention)
+    3. Manage different attention implementations (standard/flash attention)
     4. Support both prefill and decode phases

     Notes:
@@ -128,7 +145,9 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_attention_mask: bool,
         kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
     ):
         super().__init__()
         self.config = causal_lm.config
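With the two new constructor arguments in place, wrapping a model looks roughly like the sketch below; the values are illustrative (not defaults), and the leading `causal_lm` argument is inferred from the surrounding code rather than shown in this hunk:

    # Illustrative sketch only; argument values are examples, not defaults.
    wrapper = DecoderOnlyWrapper(
        causal_lm=model,             # a Hugging Face decoder-only *ForCausalLM
        max_seq_len=8192,
        use_rotary_emb=True,
        attn_impl="eager",
        use_attention_mask=True,     # new: selects masked vs. causal custom ops
        kvcache_partition_len=None,  # only relevant for attn_impl="flash_attn"
        kvcache_block_size=8192,     # new: must equal max_seq_len for "eager"
    )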
@@ -139,12 +158,20 @@ class DecoderOnlyWrapper(nn.Module):
             self.rotary_emb = None

         self.attn_impl = attn_impl
+        self.kvcache_block_size = kvcache_block_size
+        self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-            register_rbln_custom_flash_attention()
+            if self.use_attention_mask:
+                register_rbln_custom_paged_flash_attention()
+            else:
+                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-            register_rbln_custom_attention()
+            if self.use_attention_mask:
+                register_rbln_custom_paged_attention()
+            else:
+                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

@@ -154,7 +181,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )

-        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
+        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)

         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
@@ -162,21 +189,32 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

-    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
+    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_attention_mask=self.use_attention_mask,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")

             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = DecoderOnlyModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_model = DecoderOnlyModel(
+            causal_lm.model,
+            new_layers,
+            partition_len=self.kvcache_partition_len,
+            max_seq_len=max_seq_len,
+            kvcache_block_size=self.kvcache_block_size,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm

@@ -191,23 +229,43 @@ class DecoderOnlyWrapper(nn.Module):

     def forward(self, *args):
         if self.phase == "decode":
-            (
-                input_ids_or_inputs_embeds,
-                attention_mask,
-                cache_position,
-                *past_key_values,
-            ) = args
-            batch_position = torch.tensor(0, dtype=torch.int16)
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
             query_position = None
         elif self.phase == "prefill":
-            (
-                input_ids_or_inputs_embeds,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_position,
-                *past_key_values,
-            ) = args
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    query_position,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    query_position,
+                    block_tables,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
+
         else:
             raise ValueError(f"Unknown phase: {self.phase}")

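For reference, the flattened positional inputs that the traced graph now expects, reconstructed from the unpacking above:

    # decode,  use_attention_mask=True : (inputs, cache_position, attention_mask, block_tables, *past_key_values)
    # decode,  use_attention_mask=False: (inputs, cache_position, block_tables, *past_key_values)
    # prefill, use_attention_mask=True : (inputs, cache_position, attention_mask, query_position, block_tables, *past_key_values)
    # prefill, use_attention_mask=False: (inputs, cache_position, query_position, block_tables, *past_key_values)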
@@ -235,26 +293,18 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values

-        logit, present_key_values = self.causal_lm(
+        logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=self.rotary_emb,
+            block_tables=block_tables,
         )

-        # ((key, value)) * n_layer -> [key, value] * n_layer
-        _present_key_values = ()
-        for i in range(self.num_hidden_layers):
-            key_states = present_key_values[i][0]
-            value_states = present_key_values[i][1]
-            _present_key_values = _present_key_values + (key_states, value_states)
-        present_key_values = _present_key_values
-
-        return logit, present_key_values
+        return logit


 class DecoderOnlyForCausalLM(nn.Module):
@@ -301,28 +351,27 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # outputs
-        hidden_states, present_key_values = self.model(
+        hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
+            block_tables=block_tables,
         )

         if self.phase == "prefill":
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

         logits = self._original_mod.lm_head(hidden_states)
-        output = (logits, present_key_values)
-        return output
+        return logits


 class DecoderOnlyModel(nn.Module):
@@ -338,12 +387,16 @@ class DecoderOnlyModel(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """

-    def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
+    def __init__(
+        self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+    ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self._phase = "prefill"
         self.partition_len = partition_len
+        self.kvcache_block_size = kvcache_block_size
+        self.max_seq_len = max_seq_len

     @property
     def phase(self):
@@ -364,9 +417,8 @@ class DecoderOnlyModel(nn.Module):
         return 1

     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-        if self.attn_impl != "flash_attn":
+        if self.attn_impl not in ["flash_attn"]:
             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-
         partition_len = self.partition_len
         num_partition = max_seq_len // partition_len

@@ -392,9 +444,9 @@ class DecoderOnlyModel(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
-        batch_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -410,7 +462,7 @@ class DecoderOnlyModel(nn.Module):

         # get cos,sin vector if needed
         if rotary_emb is not None:
-            cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1])  # dtype carrier, max_seq_len
+            cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
             batch_size = inputs_embeds.shape[0]
@@ -429,27 +481,25 @@ class DecoderOnlyModel(nn.Module):
         # (batch, seq_len) -> (batch,)
         if self.attn_impl == "flash_attn":
             seq_positions = cache_position[:, 0]
-            max_seq_len = past_key_values[0][0].shape[-2]
             seq_positions = self.convert_sequence_positions_for_flash_attn(
-                seq_positions=seq_positions, max_seq_len=max_seq_len
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
             )
         else:
             seq_positions = cache_position[:, :1]

-        present_key_values = past_key_values
         for layer in self.layers:
-            hidden_states, present_key_values = layer(
+            hidden_states = layer(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 seq_positions=seq_positions,
-                batch_position=batch_position,
-                past_key_values=present_key_values,
+                past_key_values=past_key_values,
                 cos=cos,
                 sin=sin,
+                block_tables=block_tables,
             )

         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states, present_key_values
+        return hidden_states


 class DecoderOnlyLayer(nn.Module):
@@ -503,22 +553,22 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)

-        hidden_states, present_key_values = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states

@@ -528,7 +578,7 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states

-        return hidden_states, present_key_values
+        return hidden_states


 class DecoderOnlyAttention(nn.Module):
@@ -542,7 +592,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """

-    def __init__(self, self_attn):
+    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -560,7 +610,9 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads

+        self.use_attention_mask = use_attention_mask
         self.attention = self.get_attention()
+        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

     @property
@@ -573,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase

     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
+        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)

     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -606,10 +658,10 @@ class DecoderOnlyAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

@@ -628,32 +680,31 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and self.phase == "prefill":
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

-        attn_output, key_state, value_state = self.attention(
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state

         attn_outputs = self.o_proj(attn_output)
-        past_key_values[self.layer_idx] = key_states, value_states
-        return attn_outputs, past_key_values
+        return attn_outputs


 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
+        self.use_attention_mask = use_attention_mask

     def forward(
         self,
@@ -661,11 +712,12 @@ class AttentionOp(nn.Module):
         key_state: torch.Tensor,
         value_state: torch.Tensor,
         attn_mask: torch.Tensor,
-        batch_position: torch.Tensor,
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: torch.Tensor,
         scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.

@@ -674,19 +726,19 @@ class AttentionOp(nn.Module):
             key_state: Key tensor [1, num_heads, seq_len, head_dim]
             value_state: Value tensor [1, num_heads, seq_len, head_dim]
             attn_mask: Attention mask tensor ∈ {0, 1}
-            batch_position: Batch index for cache lookup
             past_key_state: Previous key cache states
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
             scale: Scale applied to attn weights

         Returns:
-            Tuple of (attention_output, key_state, value_state)
+            Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
         """
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -702,35 +754,64 @@ class AttentionOp(nn.Module):
         )

         if self.phase == "decode":
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_decode(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                seq_position,
-                scale,
-            )
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )

         else:
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_prefill(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                batch_position,
-                seq_position,
-                scale,
-            )
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    block_size,
+                )

         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

-        return attn_output, key_state.squeeze(2), value_state.squeeze(2)
+        return attn_output


 def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
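The `block_tables` and `block_size` arguments passed to these custom ops follow the usual paged KV-cache convention, where each sequence's logically contiguous cache is scattered across fixed-size physical blocks. As a rough sketch of that convention (illustrative only; the RBLN custom ops perform this lookup inside the compiled kernel):

    def locate_kv_slot(block_tables, batch_idx, position, block_size):
        # Which logical block the token position falls into, and its offset within it.
        logical_block = position // block_size
        offset = position % block_size
        # The block table maps each sequence's logical blocks to physical cache blocks.
        physical_block = block_tables[batch_idx][logical_block]
        return physical_block, offset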
@@ -826,22 +907,30 @@ class RotaryEmbedding(nn.Module):


 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
         self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(self_attn=self_attn)
+        super().__init__(
+            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+        )

     def get_attention(self):
-        return FlashAttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.kvcache_partition_size)
+        return FlashAttentionOp(
+            self.num_heads,
+            self.head_dim,
+            self.num_key_value_heads,
+            self.kvcache_partition_size,
+            self.use_attention_mask,
+        )

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

@@ -857,29 +946,38 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         if cos is not None and sin is not None:
             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)

-        attn_output, key_state, value_state = self.attention(
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
             attention_mask,
             past_key_state=past_key_values[self.layer_idx][0],
             past_value_state=past_key_values[self.layer_idx][1],
-            batch_position=None if self.phase == "decode" else batch_position,
             seq_position=seq_positions,
             scale=self.scale,
+            block_tables=block_tables,
+            kvcache_block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state

         attn_outputs = self.o_proj(attn_output)
-        past_key_values[self.layer_idx] = key_states, value_states
-
-        return attn_outputs, past_key_values
+        return attn_outputs


 class FlashAttentionOp(AttentionOp):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, kvcache_partition_len: int):
-        super().__init__(num_heads=num_heads, head_dim=head_dim, num_key_value_heads=num_key_value_heads)
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        kvcache_partition_len: int,
+        use_attention_mask: bool,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            use_attention_mask=use_attention_mask,
+        )
         self.kvcache_partition_size = kvcache_partition_len

     def forward(
@@ -888,16 +986,18 @@ class FlashAttentionOp(AttentionOp):
         key_state,
         value_state,
         attn_mask,
-        batch_position,
         past_key_state,
         past_value_state,
         seq_position,
         scale,
+        block_tables,
+        kvcache_block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
+        if self.use_attention_mask:
+            attn_mask = attn_mask.unsqueeze(2)

         if self.phase == "decode":
             batch_size = key_state.shape[0]
@@ -913,34 +1013,65 @@ class FlashAttentionOp(AttentionOp):
         )

         if self.phase == "decode":
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                seq_position,
-                scale,
-                self.kvcache_partition_size,
-            )
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )
         else:
-            attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
-                query_state,
-                key_state,
-                value_state,
-                attn_mask,
-                past_key_state.unsqueeze(2),
-                past_value_state.unsqueeze(2),
-                batch_position,
-                seq_position,
-                scale,
-                self.kvcache_partition_size,
-            )
+            if self.use_attention_mask:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    attn_mask,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )
+            else:
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
+                    query_state,
+                    key_state,
+                    value_state,
+                    past_key_state.unsqueeze(2),
+                    past_value_state.unsqueeze(2),
+                    seq_position,
+                    scale,
+                    block_tables,
+                    kvcache_block_size,
+                    self.kvcache_partition_size,
+                )

         # reshape for removing repeat_kv
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

-        return attn_output, key_state, value_state
+        return attn_output