optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3a0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (90)
  1. optimum/rbln/__init__.py +8 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +6 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +10 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -257
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  81. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  82. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  83. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  84. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  85. optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
  86. optimum/rbln/utils/runtime_utils.py +3 -3
  87. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
  88. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
  89. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
  90. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -13,16 +13,19 @@
 # limitations under the License.

 import math
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

 from ....utils import logging
-from ...modeling_attention_utils import DEFAULT_FLASH_ATTN_PARTITION_LENGTH
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from .configuration_decoderonly import CacheImplType
+from ...utils.rbln_quantization import RBLNQuantizationConfig
+
+
+if TYPE_CHECKING:
+    from .configuration_decoderonly import RBLNDecoderOnlyModelConfig


 logger = logging.get_logger(__name__)
@@ -42,16 +45,9 @@ class DecoderOnlyWrapper(nn.Module):
     - Wrapper should not contain neural network graph operations (including memory view handling)

     Args:
-        causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
-        max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
+        model (PreTrainedModel): The Huggingface causal language model to wrap
+        rbln_config: The RBLN model configuration containing all necessary parameters
         use_rotary_emb (bool): Whether to use rotary position embeddings
-        attn_impl (str): The attention implementation to use.
-            - "eager": Uses the standard attention.
-            - "flash_attn": Uses flash attention. When set,
-              the key/value cache is partitioned into chunks of length
-              `kvcache_partition_len`.
-        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
-            This is only relevant if `attn_impl` is set to "flash_attn`
     """

     _use_learned_pos_emb = False
@@ -59,24 +55,17 @@ class DecoderOnlyWrapper(nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        max_seq_len: int,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         use_rotary_emb: bool,
-        attn_impl: str,
-        cache_impl: CacheImplType,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        kvcache_partition_len: Optional[int] = None,
-        kvcache_block_size: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        sliding_window_layers: Optional[List[int]] = None,
     ):
         super().__init__()
+        self.quantization = rbln_config.quantization
         self.config = model.config
         self.is_causal_lm = getattr(model, "lm_head", None) is not None
+        self.rbln_config = rbln_config

         if use_rotary_emb:
-            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
             if isinstance(rotary_embs, tuple):
                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
             else:
@@ -84,31 +73,13 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             self.rotary_emb = None

-        self.attn_impl = attn_impl
-        self.kvcache_block_size = kvcache_block_size
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
-        self.use_inputs_embeds = use_inputs_embeds
-        self.sliding_window_layers = sliding_window_layers
-        self.cache_impl = cache_impl
-        self.use_global_attention = cache_impl in ["static", "hybrid"]
-        self.use_local_attention = cache_impl in ["hybrid", "sliding_window"]
-        self.sliding_window = sliding_window
-
-        if self.attn_impl == "flash_attn":
-            self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-        elif self.attn_impl == "eager":
-            self.kvcache_partition_len = None
-        else:
-            raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
-
-        if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+        if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
             raise ValueError(
-                f"kvcache_partition_len({kvcache_partition_len}) should be lower"
-                f" or equal to max_seq_len({max_seq_len})!"
+                f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+                f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )

-        self.model = self.convert_to_rbln_class(model, max_seq_len)
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"

@@ -139,17 +110,9 @@ class DecoderOnlyWrapper(nn.Module):
     def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
-            is_sliding = layer_idx in self.sliding_window_layers
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
             new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer),
-                self.use_attention_mask if not is_sliding else True,
-                self.use_position_ids,
-                kvcache_block_size=self.sliding_window
-                if layer_idx in self.sliding_window_layers
-                else self.kvcache_block_size,
-                is_sliding=is_sliding,
-                attn_impl=self.attn_impl if not is_sliding else "eager",
-                kvcache_partition_len=self.kvcache_partition_len,
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
             new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)
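Taken together, the hunks above collapse DecoderOnlyWrapper's long argument list into a single rbln_config object, and convert_to_rbln_class now simply forwards that config to each RBLN attention layer. A minimal sketch of what construction looks like after this change; the RBLNDecoderOnlyModelConfig constructor keywords below are assumptions for illustration, and only the attributes actually read in these hunks (max_seq_len, attn_impl, quantization, kvcache_partition_len, sliding_window_layers) are confirmed by the diff:

```python
# Hedged sketch: the config kwargs and the example checkpoint are assumptions, not a documented recipe.
from transformers import AutoModelForCausalLM

from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # any rotary decoder-only checkpoint
rbln_config = RBLNDecoderOnlyModelConfig(max_seq_len=8192, attn_impl="flash_attn")  # assumed kwargs

# 0.8.2a7 passed max_seq_len, attn_impl, cache_impl, kvcache_* and the sliding-window options
# as separate arguments; 0.8.3a0 reads them all from rbln_config.
wrapper = DecoderOnlyWrapper(model=model, rbln_config=rbln_config, use_rotary_emb=True)
```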
@@ -157,11 +120,8 @@ class DecoderOnlyWrapper(nn.Module):
         new_model = self.get_rbln_model_class()(
             self.get_model_layer(model),
             new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            kvcache_block_size=self.kvcache_block_size,
+            self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-            sliding_window_layers=self.sliding_window_layers,
         )

         if self.is_causal_lm:
@@ -181,19 +141,19 @@ class DecoderOnlyWrapper(nn.Module):

     def prepare_forward_args(self, *args):
         args = list(args)
-        input_ids = None if self.use_inputs_embeds else args.pop(0)
-        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.use_global_attention else None
-        local_block_tables = args.pop(0) if self.use_local_attention else None
+        global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+        local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
         query_position = (
             args.pop(0)
             # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
-            if ("prefill" in self.phase and (self.is_causal_lm or self.use_local_attention))
+            if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
             else None
         )
-        attention_mask = args.pop(0) if self.use_attention_mask else None
-        position_ids = args.pop(0) if self.use_position_ids else None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
         past_key_values = args

         if len(past_key_values) != 2 * self.num_hidden_layers:
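prepare_forward_args still consumes the same flat positional layout; the hunk above only changes where the feature flags are read from (rbln_config instead of wrapper attributes). For reference, a sketch of the ordering implied by the pop(0) sequence; the helper name and signature are illustrative and not part of the library:

```python
from typing import Any, List, Sequence


def flatten_forward_args(
    rbln_config: Any,                # object exposing the use_* flags read above (illustrative)
    is_causal_lm: bool,
    phase: str,
    inputs: Any,                     # input_ids or inputs_embeds, whichever the config enables
    cache_position: Any,
    global_block_tables: Any,
    local_block_tables: Any,
    query_position: Any,
    attention_mask: Any,
    position_ids: Any,
    past_key_values: Sequence[Any],  # 2 * num_hidden_layers cache tensors
) -> List[Any]:
    """Mirror of the pop(0) order in prepare_forward_args above (illustrative only)."""
    args: List[Any] = [inputs, cache_position]
    if rbln_config.use_global_attention:
        args.append(global_block_tables)
    if rbln_config.use_local_attention:
        args.append(local_block_tables)
    if "prefill" in phase and (is_causal_lm or rbln_config.use_local_attention):
        args.append(query_position)
    if rbln_config.use_attention_mask:
        args.append(attention_mask)
    if rbln_config.use_position_ids:
        args.append(position_ids)
    args.extend(past_key_values)
    return args
```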
@@ -345,6 +305,8 @@ class DecoderOnlyModel(nn.Module):
     Args:
         model: Original Huggingface model to adapt
         layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+        rbln_config: RBLN model configuration
+        use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)

     Attributes:
         _original_mod: Reference to original Huggingface model
@@ -356,21 +318,19 @@ class DecoderOnlyModel(nn.Module):
         self,
         model,
         layers: List["DecoderOnlyLayer"],
-        partition_len=None,
-        max_seq_len=None,
-        kvcache_block_size=None,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
-        sliding_window_layers=None,
     ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
+        self.rbln_config = rbln_config
         self._phase = "prefill"
-        self.partition_len = partition_len
-        self.kvcache_block_size = kvcache_block_size
-        self.max_seq_len = max_seq_len
+        self.partition_len = rbln_config.kvcache_partition_len
+        self.kvcache_block_size = rbln_config.kvcache_block_size
+        self.max_seq_len = rbln_config.max_seq_len
         self.use_learned_pos_emb = use_learned_pos_emb
-        self.sliding_window_layers = sliding_window_layers
+        self.sliding_window_layers = rbln_config.sliding_window_layers

     @property
     def phase(self):
@@ -600,25 +560,19 @@ class DecoderOnlyAttention(nn.Module):

     Args:
         self_attn: Original attention module from the base model
-        use_attention_mask: Whether to use attention mask
-        use_position_ids: Whether to use position ids
-        kvcache_block_size: Block size for KV cache
+        rbln_config: RBLN model configuration containing attention parameters
         is_sliding: Whether this is sliding window attention
-        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """

     def __init__(
         self,
         self_attn,
-        use_attention_mask,
-        use_position_ids,
-        kvcache_block_size,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         is_sliding=False,
-        attn_impl="eager",
-        kvcache_partition_len=None,
     ):
         super().__init__()
         self._original_mod = self_attn
+        self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
         self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
             self._original_mod.config, "num_attention_heads"
@@ -626,6 +580,7 @@ class DecoderOnlyAttention(nn.Module):
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
         self.scale = torch.tensor(self.get_attn_scale())
+        self.quantization = rbln_config.quantization

         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -634,14 +589,14 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads

-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
+        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
-        self.attn_impl = attn_impl
-        self.kvcache_partition_len = kvcache_partition_len
+        self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+        self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+        self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size

         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

     def get_attention_name(self):
@@ -681,6 +636,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.kvcache_partition_len,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
@@ -689,6 +645,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_key_value_heads,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
@@ -719,6 +676,16 @@ class DecoderOnlyAttention(nn.Module):
     def get_attn_scale(self):
         return 1 / math.sqrt(self.head_dim)

+    def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+            k_scale = getattr(self.k_proj, "k_scale", None)
+            v_scale = getattr(self.v_proj, "v_scale", None)
+        else:
+            k_scale = None
+            v_scale = None
+
+        return k_scale, v_scale
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -748,6 +715,8 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

+        k_scale, v_scale = self.maybe_get_kvcache_scale()
+
         attn_output = self.get_attention_op()(
             query_states,
             key_states,
@@ -759,6 +728,8 @@ class DecoderOnlyAttention(nn.Module):
             scale=self.scale,
             block_tables=block_tables,
             block_size=self.kvcache_block_size,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )

         attn_outputs = self.o_proj(attn_output)
@@ -775,7 +746,13 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):

 class AttentionOp(nn.Module):
     def __init__(
-        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        use_attention_mask: bool,
+        use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -784,10 +761,10 @@ class AttentionOp(nn.Module):
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
+        self.quantization = quantization

     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
         if self.use_attention_mask and not self.use_position_ids:
             attn_op_name = "paged_attn_"
         else:
@@ -795,6 +772,9 @@ class AttentionOp(nn.Module):

         attn_op_name += phase

+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name

     def forward(
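With this release the fp8 KV-cache path is selected purely by op name: get_attn_op_name appends a _kv_fp8 suffix and forward later resolves that name on torch.ops.rbln_custom_ops (the matching custom ops are presumably what the expanded optimum/rbln/ops/attn.py and optimum/rbln/ops/flash_attn.py register). A small illustrative restatement of the naming rule; the helper below is not part of the library:

```python
def resolve_attn_op_name(base: str, phase: str, kv_fp8: bool) -> str:
    """Compose a custom-op name the way get_attn_op_name does after this change."""
    name = base + ("decode" if phase == "decode" else "prefill")
    if kv_fp8:
        name += "_kv_fp8"
    return name


# "paged_attn_" is the masked, non-position-ids base shown in the hunk above; the suffix is
# appended only when rbln_config.quantization.kv_caches == "fp8".
assert resolve_attn_op_name("paged_attn_", "decode", kv_fp8=True) == "paged_attn_decode_kv_fp8"
assert resolve_attn_op_name("paged_attn_", "image_prefill", kv_fp8=False) == "paged_attn_prefill"
```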
@@ -809,6 +789,8 @@ class AttentionOp(nn.Module):
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.

@@ -821,6 +803,10 @@ class AttentionOp(nn.Module):
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
             scale: Scale applied to attn weights
+            block_tables: Block tables for paged attention
+            block_size: Block size for paged attention
+            k_scale: Scale applied to key
+            v_scale: Scale applied to value

         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -864,6 +850,12 @@ class AttentionOp(nn.Module):
         if not self.use_attention_mask or self.use_position_ids:
             op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
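The dtype guard above implies the KV cache itself is stored as torch.float8_e4m3fn, with k_scale and v_scale carrying the dequantization factors passed through to the custom op. A standalone sketch of that storage scheme; per-tensor scaling is used here for simplicity, and the actual RBLN kernel and scale granularity are not part of this diff:

```python
import torch

# Illustrative fp8 (e4m3) round-trip for a key/value tensor.
# 448.0 is the largest finite value representable in float8_e4m3fn.
kv = torch.randn(1, 8, 1024, 128, dtype=torch.float32)
k_scale = kv.abs().amax() / 448.0

kv_fp8 = (kv / k_scale).to(torch.float8_e4m3fn)  # what the cache would hold (dtype checked above)
kv_deq = kv_fp8.to(torch.float32) * k_scale      # what the attention op reconstructs using k_scale
```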
@@ -886,6 +878,7 @@ class FlashAttentionOp(AttentionOp):
         kvcache_partition_len: int,
         use_attention_mask: bool,
         use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__(
             num_heads=num_heads,
@@ -893,6 +886,7 @@ class FlashAttentionOp(AttentionOp):
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
             use_position_ids=use_position_ids,
+            quantization=quantization,
         )
         self.kvcache_partition_size = kvcache_partition_len

@@ -905,6 +899,9 @@ class FlashAttentionOp(AttentionOp):

         attn_op_name += phase

+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name

     def forward(
@@ -919,6 +916,8 @@ class FlashAttentionOp(AttentionOp):
         scale,
         block_tables,
         block_size,
+        k_scale=None,
+        v_scale=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -959,6 +958,12 @@ class FlashAttentionOp(AttentionOp):
         if not self.use_attention_mask or self.use_position_ids:
             op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -986,14 +991,19 @@ class SlidingWindowAttentionOp(AttentionOp):
         query_state: torch.Tensor,
         key_state: torch.Tensor,
         value_state: torch.Tensor,
-        attn_mask: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: Tuple[torch.Tensor],
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert self.quantization is None, "Sliding window attention does not support quantization"
+        assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
+
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
@@ -1025,8 +1035,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         }

         if self.phase == "prefill" or self.phase == "image_prefill":
-            if not self.use_attention_mask or self.use_position_ids:
-                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)