optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in their public registries.
Files changed (46)
  1. optimum/rbln/__init__.py +30 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +9 -4
  4. optimum/rbln/modeling.py +7 -5
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/sliding_window_attn.py +111 -0
  9. optimum/rbln/transformers/__init__.py +32 -3
  10. optimum/rbln/transformers/models/__init__.py +37 -0
  11. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  12. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  13. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  14. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
  15. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
  16. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
  17. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +189 -90
  18. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +186 -95
  19. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  20. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  21. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  22. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
  23. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
  24. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
  25. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  26. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
  27. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  28. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  29. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  30. optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
  31. optimum/rbln/transformers/models/opt/modeling_opt.py +80 -0
  32. optimum/rbln/transformers/models/opt/opt_architecture.py +77 -0
  33. optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
  34. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -11
  35. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
  36. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
  37. optimum/rbln/transformers/models/siglip/__init__.py +20 -0
  38. optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
  39. optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
  40. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
  41. optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
  42. optimum/rbln/utils/submodule.py +13 -1
  43. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
  44. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +46 -31
  45. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
  46. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -146,7 +146,10 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
+        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
     ):
@@ -154,13 +157,21 @@ class DecoderOnlyWrapper(nn.Module):
         self.config = causal_lm.config
 
         if use_rotary_emb:
-            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
         else:
             self.rotary_emb = None
 
         self.attn_impl = attn_impl
         self.kvcache_block_size = kvcache_block_size
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
+        self.use_inputs_embeds = use_inputs_embeds
+        self.use_learned_pos_emb = use_learned_pos_emb
+
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
         elif self.attn_impl == "eager":
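The tuple branch above accommodates models that keep two rotary tables, one for global-attention layers and one for sliding-window (local) layers, as Gemma3-style models do. A minimal sketch of a `get_rotary_emb` override returning that tuple form; `RotaryEmbedding` is the helper defined later in this file, and its constructor arguments here are assumptions, not taken from this diff:

def get_rotary_emb(self, max_seq_len):
    # Hypothetical override: one RoPE table for global layers, one for
    # local (sliding-window) layers; returning a tuple makes the wrapper
    # store rotary_emb_global / rotary_emb_local instead of rotary_emb.
    global_rope = RotaryEmbedding(self.config, max_seq_len=max_seq_len)
    local_rope = RotaryEmbedding(self.config, max_seq_len=max_seq_len)
    return global_rope, local_rope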
@@ -188,7 +199,10 @@ class DecoderOnlyWrapper(nn.Module):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    self.use_position_ids,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -196,6 +210,7 @@ class DecoderOnlyWrapper(nn.Module):
                     kvcache_partition_len=self.kvcache_partition_len,
                     kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -209,6 +224,7 @@ class DecoderOnlyWrapper(nn.Module):
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
+            use_learned_pos_emb=self.use_learned_pos_emb,
         )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
@@ -222,24 +238,16 @@ class DecoderOnlyWrapper(nn.Module):
         self._phase = phase
         self.causal_lm.phase = phase
 
-    def forward_common(
-        self,
-        input_ids_or_inputs_embeds: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: torch.Tensor,
-        query_position: torch.Tensor,
-        block_tables: torch.Tensor,
-        rotary_emb: Union[nn.Module, torch.Tensor],
-        *past_key_values: List[torch.Tensor],
-    ):
-        if input_ids_or_inputs_embeds.ndim == 2:
-            input_ids = input_ids_or_inputs_embeds
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        block_tables = args.pop(0)
+        query_position = args.pop(0) if self.phase == "prefill" else None
+        attention_mask = args.pop(0) if self.use_attention_mask else None
+        position_ids = args.pop(0) if self.use_position_ids else None
+        past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
             raise ValueError(
@@ -256,11 +264,37 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values
 
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            self.rotary_emb,
+        )
+
+    def forward(self, *args):
+        (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            rotary_emb,
+        ) = self.prepare_forward_args(*args)
+
         logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            position_ids=position_ids,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
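Because the compiled graph takes a flat positional signature, the argument order that `prepare_forward_args` expects is fixed by the wrapper's flags. A sketch of assembling that flat list on the host side; the helper name and variables are illustrative, not part of the package:

def build_forward_args(wrapper, tokens, cache_position, block_tables,
                       query_position=None, attention_mask=None,
                       position_ids=None, past_key_values=()):
    # tokens is input_ids when wrapper.use_inputs_embeds is False,
    # inputs_embeds when it is True; exactly one of the two is traced.
    args = [tokens, cache_position, block_tables]
    if wrapper.phase == "prefill":
        args.append(query_position)    # consumed only in the prefill graph
    if wrapper.use_attention_mask:
        args.append(attention_mask)
    if wrapper.use_position_ids:
        args.append(position_ids)
    args.extend(past_key_values)       # 2 * num_hidden_layers cache tensors
    return args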
@@ -269,58 +303,6 @@ class DecoderOnlyWrapper(nn.Module):
 
         return logit
 
-    def forward(self, *args):
-        if self.phase == "decode":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-            query_position = None
-        elif self.phase == "prefill":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    query_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    query_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-
-        else:
-            raise ValueError(f"Unknown phase: {self.phase}")
-
-        return self.forward_common(
-            input_ids_or_inputs_embeds,
-            cache_position,
-            attention_mask,
-            query_position,
-            block_tables,
-            self.rotary_emb,
-            *past_key_values,
-        )
-
 
 class DecoderOnlyForCausalLM(nn.Module):
     """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -367,6 +349,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
@@ -378,6 +361,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
             block_tables=block_tables,
@@ -387,6 +371,13 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self.lm_head(hidden_states)
+
+        # Apply final logit softmaxing if configured, e.g. for Gemma2
+        if getattr(self.config, "final_logit_softcapping", None) is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
         return logits
 
 
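The added block applies tanh soft-capping to the final logits, the technique Gemma2 uses to bound logit magnitudes. As a standalone sketch, with `cap` standing in for `config.final_logit_softcapping`:

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds logits to (-cap, cap); near zero it is close to the
    # identity, so only extreme logits are compressed.
    return cap * torch.tanh(logits / cap)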
@@ -404,7 +395,13 @@ class DecoderOnlyModel(nn.Module):
     """
 
     def __init__(
-        self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+        self,
+        model,
+        layers: List["DecoderOnlyLayer"],
+        partition_len=None,
+        max_seq_len=None,
+        kvcache_block_size=None,
+        use_learned_pos_emb=None,
     ):
         super().__init__()
         self._original_mod = model
@@ -413,6 +410,7 @@ class DecoderOnlyModel(nn.Module):
         self.partition_len = partition_len
         self.kvcache_block_size = kvcache_block_size
         self.max_seq_len = max_seq_len
+        self.use_learned_pos_emb = use_learned_pos_emb
 
     @property
     def phase(self):
@@ -457,11 +455,12 @@ class DecoderOnlyModel(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor = None,
-        inputs_embeds: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
-        rotary_emb: nn.Module = None,
+        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
@@ -477,24 +476,38 @@ class DecoderOnlyModel(nn.Module):
             hidden_states = inputs_embeds * self.hidden_multiplier
 
         # get cos,sin vector if needed
+        position_ids = position_ids if position_ids is not None else cache_position
         if rotary_emb is not None:
             if isinstance(rotary_emb, torch.Tensor):
                 cos = rotary_emb[0]
                 sin = rotary_emb[1]
             else:
                 cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        elif self.use_learned_pos_emb:
+            batch_size = inputs_embeds.shape[0]
+            hidden_all = []
+            for i in range(batch_size):
+                positions_idx = position_ids[i]
+                position_weight = self.get_pos_embedding().weight[2:]
+                position = position_weight[positions_idx]
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
+            cos, sin = None, None
+
         else:
             batch_size = inputs_embeds.shape[0]
-            if cache_position.shape[0] > 1:
+            if position_ids.shape[0] > 1:
                 position_embeds = []
                 for b_idx in range(batch_size):
-                    position_embed = self.get_pos_embedding()(cache_position[b_idx])
+                    position_embed = self.get_pos_embedding()(position_ids[b_idx])
                    position_embeds.append(position_embed)
 
                 position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
             else:
-                position_embeds = self.get_pos_embedding()(cache_position)
+                position_embeds = self.get_pos_embedding()(position_ids)
             hidden_states = hidden_states + position_embeds
             cos, sin = None, None
 
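The new `use_learned_pos_emb` branch adds a learned position table to the input embeddings, skipping the table's first two rows (`weight[2:]`). That offset matches OPT's convention of reserving two positions in its learned embedding table, and OPT support is added in this release. A minimal standalone sketch with assumed sizes; the per-batch loop in the diff is equivalent to the broadcasted form here:

import torch
import torch.nn as nn

pos_table = nn.Embedding(2050, 768)           # hypothetical (positions, hidden)
inputs_embeds = torch.randn(2, 8, 768)        # (batch, seq, hidden)
position_ids = torch.arange(8).repeat(2, 1)   # (batch, seq)

usable = pos_table.weight[2:]                 # drop the two reserved rows
hidden_states = inputs_embeds + usable[position_ids]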
@@ -612,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """
 
-    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
+    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -631,6 +644,7 @@ class DecoderOnlyAttention(nn.Module):
             self.num_key_value_heads = self.num_heads
 
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
        self.attention = self.get_attention()
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
@@ -645,7 +659,9 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase
 
     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)
+        return AttentionOp(
+            self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+        )
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -718,13 +734,16 @@ class DecoderOnlyAttention(nn.Module):
 
 
 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
+    def __init__(
+        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
 
     def forward(
         self,
@@ -757,7 +776,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -774,7 +794,7 @@ class AttentionOp(nn.Module):
         )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -798,10 +818,11 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -825,6 +846,8 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
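`is_bidirectional` relaxes the causal mask during the `image_prefill` phase so that image tokens can attend to one another in both directions while text stays causal, as Gemma3-style multimodal prefill requires. A conceptual sketch of such a mask (not the RBLN kernel; the image span boundaries are assumed inputs):

import torch

def prefill_mask(seq_len: int, img_start: int, img_end: int) -> torch.Tensor:
    # Causal base: each position sees itself and everything before it.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Image tokens additionally see the whole image block (bidirectional).
    mask[img_start:img_end, img_start:img_end] = True
    return mask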
@@ -926,10 +949,13 @@ class RotaryEmbedding(nn.Module):
 
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
         self.kvcache_partition_size = kvcache_partition_len
         super().__init__(
-            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+            self_attn=self_attn,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            kvcache_block_size=kvcache_block_size,
         )
 
     def get_attention(self):
@@ -939,6 +965,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.num_key_value_heads,
             self.kvcache_partition_size,
             self.use_attention_mask,
+            self.use_position_ids,
         )
 
     def forward(
@@ -990,12 +1017,14 @@ class FlashAttentionOp(AttentionOp):
         num_key_value_heads: int,
         kvcache_partition_len: int,
         use_attention_mask: bool,
+        use_position_ids: bool,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
@@ -1015,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1032,7 +1061,7 @@ class FlashAttentionOp(AttentionOp):
         )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -1058,9 +1087,10 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -1086,6 +1116,8 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         # reshape for removing repeat_kv
@@ -1094,3 +1126,70 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class SlidingWindowAttentionOp(AttentionOp):
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: Tuple[torch.Tensor],
+        scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+
+        if self.phase == "decode":
+            batch_size = key_state.shape[0]
+        else:
+            batch_size = 1
+
+        query_state = query_state.view(
+            batch_size,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+            )
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+                is_bidirectional=True if self.phase == "image_prefill" else False,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+        return attn_output
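The sliding-window ops above keep only a window-sized cache region, located by the `cache_seq_len` and `cache_offset` values packed into `seq_position`. Conceptually, the score mask they realize on-device is the standard causal sliding-window pattern; a plain-PyTorch sketch (not the RBLN kernel):

import torch

def sliding_window_mask(q_len: int, kv_len: int, window: int) -> torch.Tensor:
    q_pos = torch.arange(kv_len - q_len, kv_len).unsqueeze(1)  # query positions
    k_pos = torch.arange(kv_len).unsqueeze(0)                  # key positions
    causal = k_pos <= q_pos            # never attend to future keys
    recent = k_pos > q_pos - window    # keep only the last `window` keys
    return causal & recent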