optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (120)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  23. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  24. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  25. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  31. optimum/rbln/modeling.py +4 -5
  32. optimum/rbln/modeling_base.py +18 -14
  33. optimum/rbln/ops/kv_cache_update.py +5 -0
  34. optimum/rbln/ops/linear.py +7 -0
  35. optimum/rbln/transformers/__init__.py +60 -0
  36. optimum/rbln/transformers/configuration_generic.py +4 -4
  37. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  38. optimum/rbln/transformers/modeling_generic.py +1 -4
  39. optimum/rbln/transformers/models/__init__.py +45 -30
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  41. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  42. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  43. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  44. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  45. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  46. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  47. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  48. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  51. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  52. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  53. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  54. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  55. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  56. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  57. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  58. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  59. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  60. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  61. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  62. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  63. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  64. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  65. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  66. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  67. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  68. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  69. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  75. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  76. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  77. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  78. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  79. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  80. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  81. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  82. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  83. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  84. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  85. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  86. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  87. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  91. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  92. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  93. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  94. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  97. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  101. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  102. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  103. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  104. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  105. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  106. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  107. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  108. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  110. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  111. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  112. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  113. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  114. optimum/rbln/utils/depreacate_utils.py +16 -0
  115. optimum/rbln/utils/hub.py +8 -47
  116. optimum/rbln/utils/runtime_utils.py +31 -5
  117. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  118. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
  119. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  120. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
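The hunks below all come from optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (entry 50 above, +323 -454). To confirm which side of this diff is installed locally, the distribution version can be read at runtime; a minimal sketch using only the standard library, assuming nothing beyond the package name shown in the wheel filenames:

# Minimal sketch: report the installed optimum-rbln version via importlib.metadata.
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("optimum-rbln"))  # expected to print "0.8.1rc0" or "0.8.2"
except PackageNotFoundError:
    print("optimum-rbln is not installed")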
@@ -20,108 +20,13 @@ from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

 from ....utils import logging
+from ...modeling_attention_utils import DEFAULT_FLASH_ATTN_PARTITION_LENGTH
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .configuration_decoderonly import CacheImplType


 logger = logging.get_logger(__name__)

-DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
-DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
-MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
-MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
-MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
-MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
-def set_default_values(
-    attn_impl: Optional[str] = None,
-    kvcache_partition_len: Optional[int] = None,
-    kvcache_block_size: Optional[int] = None,
-    max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if attn_impl is None:
-        attn_impl = "eager"
-
-    if kvcache_partition_len is not None:
-        if attn_impl == "eager":
-            attn_impl = "flash_attn"
-            logger.warning(
-                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
-                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
-                "`attn_impl` has been automatically switched to 'flash_attn'."
-            )
-
-    if kvcache_partition_len is None and attn_impl == "flash_attn":
-        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
-    if kvcache_block_size is None:
-        if attn_impl == "eager":
-            kvcache_block_size = max_seq_len
-        else:
-            kvcache_block_size = kvcache_partition_len
-
-    return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
-def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
-    if attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
-
-    ## Checking Constraints...
-    # Constraint of eager attention:
-    # - `max_seq_len` <= 32k
-
-    # Constraints of flash attention:
-    # 1. `max_seq_len` should be multiple of `partition_len`.
-    # 2. 4k <= `partition_len` <= 32k.
-    # 3. `max_seq_len` should be larger then 8k.
-    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
-        raise ValueError(
-            f"`max_seq_len` is set to {max_seq_len}, "
-            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
-        )
-
-    if attn_impl == "flash_attn":
-        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
-                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
-            )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
-            raise ValueError(
-                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
-                f"Please provide a valid value within this range."
-            )
-        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
-                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
-            )
-
-    if kvcache_block_size is not None:
-        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
-            raise ValueError(
-                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
-            )
-        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
-            raise ValueError(
-                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `max_seq_len` {max_seq_len}."
-            )
-
-
-def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
-    if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
-        raise ValueError(
-            f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
-        )
-

 class DecoderOnlyWrapper(nn.Module):
     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
@@ -149,9 +54,11 @@ class DecoderOnlyWrapper(nn.Module):
             This is only relevant if `attn_impl` is set to "flash_attn`
     """

+    _use_learned_pos_emb = False
+
     def __init__(
         self,
-        causal_lm: PreTrainedModel,
+        model: PreTrainedModel,
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
@@ -159,14 +66,14 @@ class DecoderOnlyWrapper(nn.Module):
         use_inputs_embeds: bool,
         use_attention_mask: bool,
         use_position_ids: bool,
-        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         sliding_window: Optional[int] = None,
         sliding_window_layers: Optional[List[int]] = None,
     ):
         super().__init__()
-        self.config = causal_lm.config
+        self.config = model.config
+        self.is_causal_lm = getattr(model, "lm_head", None) is not None

         if use_rotary_emb:
             rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
@@ -182,9 +89,10 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.use_inputs_embeds = use_inputs_embeds
-        self.use_learned_pos_emb = use_learned_pos_emb
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
+        self.use_global_attention = cache_impl in ["static", "hybrid"]
+        self.use_local_attention = cache_impl in ["hybrid", "sliding_window"]
         self.sliding_window = sliding_window

         if self.attn_impl == "flash_attn":
@@ -200,59 +108,67 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )

-        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
+        self.model = self.convert_to_rbln_class(model, max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"

     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

-    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.layers if self.is_causal_lm else model.layers
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.self_attn
+
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model if self.is_causal_lm else model
+
+    def get_rbln_attn_class(self):
+        return DecoderOnlyAttention
+
+    def get_rbln_layer_class(self):
+        return DecoderOnlyLayer
+
+    def get_rbln_model_class(self):
+        return DecoderOnlyModel
+
+    def get_rbln_causal_lm_class(self):
+        return DecoderOnlyForCausalLM
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(causal_lm.model.layers):
-            if layer_idx in self.sliding_window_layers:
-                # Flash attention is not yet supported for sliding window attention.
-                new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
-                    self.use_attention_mask,
-                    self.use_position_ids,
-                    kvcache_block_size=self.sliding_window,
-                    is_sliding=True,
-                )
-            else:
-                if self.attn_impl == "eager":
-                    new_self_attn = DecoderOnlyAttention(
-                        layer.self_attn,
-                        self.use_attention_mask,
-                        self.use_position_ids,
-                        kvcache_block_size=self.kvcache_block_size,
-                        is_sliding=False,
-                    )
-                elif self.attn_impl == "flash_attn":
-                    new_self_attn = DecoderOnlyFlashAttention(
-                        layer.self_attn,
-                        kvcache_partition_len=self.kvcache_partition_len,
-                        kvcache_block_size=self.kvcache_block_size,
-                        use_attention_mask=self.use_attention_mask,
-                        use_position_ids=self.use_position_ids,
-                    )
-                else:
-                    raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+            is_sliding = layer_idx in self.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.use_attention_mask if not is_sliding else True,
+                self.use_position_ids,
+                kvcache_block_size=self.sliding_window
+                if layer_idx in self.sliding_window_layers
+                else self.kvcache_block_size,
+                is_sliding=is_sliding,
+                attn_impl=self.attn_impl if not is_sliding else "eager",
+                kvcache_partition_len=self.kvcache_partition_len,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)

-        new_model = DecoderOnlyModel(
-            causal_lm.model,
+        new_model = self.get_rbln_model_class()(
+            self.get_model_layer(model),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
-            use_learned_pos_emb=self.use_learned_pos_emb,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
             sliding_window_layers=self.sliding_window_layers,
         )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+
+        if self.is_causal_lm:
+            new_model = self.get_rbln_causal_lm_class()(model, new_model)
+            return new_model
+        else:
+            return new_model

     @property
     def phase(self) -> str:
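Note: convert_to_rbln_causal_lm becomes convert_to_rbln_class, and the conversion loop is now parameterized by the small get_* hooks plus the _use_learned_pos_emb class flag, so a model-specific wrapper can swap in its own classes without re-implementing the loop. A hypothetical subclass sketch; MyAttention and MyLayer are placeholders, not classes from the package:

# Hypothetical subclass: override only the hooks whose defaults do not fit the model.
class MyModelWrapper(DecoderOnlyWrapper):
    _use_learned_pos_emb = True  # class-level flag replaces the removed constructor argument

    def get_rbln_attn_class(self):
        return MyAttention  # placeholder for a model-specific attention class

    def get_rbln_layer_class(self):
        return MyLayer  # placeholder for a model-specific layer class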
@@ -261,16 +177,21 @@ class DecoderOnlyWrapper(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.causal_lm.phase = phase
+        self.model.phase = phase

     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.use_inputs_embeds else args.pop(0)
         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
-        local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
-        query_position = args.pop(0) if "prefill" in self.phase else None
+        global_block_tables = args.pop(0) if self.use_global_attention else None
+        local_block_tables = args.pop(0) if self.use_local_attention else None
+        query_position = (
+            args.pop(0)
+            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+            if ("prefill" in self.phase and (self.is_causal_lm or self.use_local_attention))
+            else None
+        )
         attention_mask = args.pop(0) if self.use_attention_mask else None
         position_ids = args.pop(0) if self.use_position_ids else None
         past_key_values = args
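Note: prepare_forward_args still pops the flattened runtime inputs positionally, but the optional tensors are now gated on use_global_attention, use_local_attention, and the new query_position rule (prefill with a causal LM head or with local attention). An illustrative helper that mirrors that order when packing inputs; the function and its dictionary keys are assumptions for demonstration, not library API:

# Illustrative packing helper mirroring the order prepare_forward_args pops arguments.
def flatten_forward_args(t, *, use_inputs_embeds, use_global_attention, use_local_attention,
                         is_prefill, is_causal_lm, use_attention_mask, use_position_ids):
    args = [t["inputs_embeds"] if use_inputs_embeds else t["input_ids"], t["cache_position"]]
    if use_global_attention:
        args.append(t["global_block_tables"])
    if use_local_attention:
        args.append(t["local_block_tables"])
    if is_prefill and (is_causal_lm or use_local_attention):
        args.append(t["query_position"])
    if use_attention_mask:
        args.append(t["attention_mask"])
    if use_position_ids:
        args.append(t["position_ids"])
    return args + list(t["past_key_values"])  # remaining positional args are the KV cache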
@@ -322,7 +243,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)

-        logit = self.causal_lm(
+        logit = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -679,9 +600,23 @@ class DecoderOnlyAttention(nn.Module):

     Args:
         self_attn: Original attention module from the base model
+        use_attention_mask: Whether to use attention mask
+        use_position_ids: Whether to use position ids
+        kvcache_block_size: Block size for KV cache
+        is_sliding: Whether this is sliding window attention
+        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """

-    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size, is_sliding=False):
+    def __init__(
+        self,
+        self_attn,
+        use_attention_mask,
+        use_position_ids,
+        kvcache_block_size,
+        is_sliding=False,
+        attn_impl="eager",
+        kvcache_partition_len=None,
+    ):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -702,10 +637,24 @@ class DecoderOnlyAttention(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
         self.is_sliding = is_sliding
-        self.attention = self.get_attention()
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+
+        setattr(self, self.get_attention_name(), self.create_attention_op())
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()

+    def get_attention_name(self):
+        if self.is_sliding:
+            return "sliding_window_attention"
+        elif self.attn_impl == "flash_attn":
+            return "flash_attention"
+        else:
+            return "attention"
+
+    def get_attention_op(self):
+        return getattr(self, self.get_attention_name())
+
     @property
     def phase(self):
         return self._phase
@@ -713,17 +662,36 @@ class DecoderOnlyAttention(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.attention.phase = phase
+        getattr(self, self.get_attention_name()).phase = phase

-    def get_attention(self):
+    def create_attention_op(self):
         if self.is_sliding:
             return SlidingWindowAttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
-        else:
+        elif self.attn_impl == "flash_attn":
+            return FlashAttentionOp(
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.kvcache_partition_len,
+                self.use_attention_mask,
+                self.use_position_ids,
+            )
+        elif self.attn_impl == "eager":
             return AttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+                self.num_heads,
+                self.head_dim,
+                self.num_key_value_heads,
+                self.use_attention_mask,
+                self.use_position_ids,
             )
+        else:
+            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
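Note: create_attention_op folds the previous eager/flash/sliding split into one dispatch keyed on is_sliding and attn_impl, and the chosen op is stored under a matching attribute name from get_attention_name. The mapping written out as a standalone function (a summary of the code above, not library API):

# Summary of the dispatch above: which op class is built and under which attribute it is stored.
def resolve_attention(is_sliding: bool, attn_impl: str):
    if is_sliding:
        return "SlidingWindowAttentionOp", "sliding_window_attention"
    if attn_impl == "flash_attn":
        return "FlashAttentionOp", "flash_attention"
    if attn_impl == "eager":
        return "AttentionOp", "attention"
    raise NotImplementedError(f"Unknown attention implementation: {attn_impl}")


print(resolve_attention(False, "flash_attn"))  # ('FlashAttentionOp', 'flash_attention')
print(resolve_attention(True, "eager"))        # ('SlidingWindowAttentionOp', 'sliding_window_attention')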
@@ -780,7 +748,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

-        attn_output = self.attention(
+        attn_output = self.get_attention_op()(
             query_states,
             key_states,
             value_states,
@@ -797,6 +765,14 @@ class DecoderOnlyAttention(nn.Module):
         return attn_outputs


+class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        logger.warning(
+            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+        )
+
+
 class AttentionOp(nn.Module):
     def __init__(
         self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
@@ -809,6 +785,18 @@ class AttentionOp(nn.Module):
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids

+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_attn_"
+        else:
+            attn_op_name = "paged_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
@@ -857,63 +845,31 @@ class AttentionOp(nn.Module):
             self.head_dim,
         )

-        if self.phase == "decode":
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=block_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.use_attention_mask:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
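Note: AttentionOp.forward now builds a single op_args dict and resolves the kernel through get_attn_op_name, so the decode/prefill and masked/causal variants differ only in the name passed to getattr(torch.ops.rbln_custom_ops, ...). The four names the base class resolves to, enumerated (a sketch mirroring the method above):

# Mirrors AttentionOp.get_attn_op_name: mask-without-position-ids selects the plain paged
# kernels, everything else the causal ones; any non-"decode" phase maps to prefill.
def base_attn_op_name(phase: str, use_attention_mask: bool, use_position_ids: bool) -> str:
    prefix = "paged_attn_" if use_attention_mask and not use_position_ids else "paged_causal_attn_"
    return prefix + ("decode" if phase == "decode" else "prefill")


for phase in ("decode", "prefill"):
    for masked in (True, False):
        print(phase, masked, base_attn_op_name(phase, masked, use_position_ids=False))
# decode True paged_attn_decode / decode False paged_causal_attn_decode
# prefill True paged_attn_prefill / prefill False paged_causal_attn_prefill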
@@ -921,161 +877,6 @@ class AttentionOp(nn.Module):
         return attn_output


-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., :ndim],
-        query_states[..., ndim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., :ndim],
-        key_states[..., ndim:],
-    )
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
-    return query_states, key_states
-
-
-class RotaryEmbedding(nn.Module):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        max_seq_len_cached: int,
-    ):
-        super().__init__()
-
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            rope_type = "default"
-
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        cache_position_expanded = cache_position[:, None]
-
-        if rope_type == "dynamic":
-            freqs = cache_position_expanded.float() * inv_freq.float()
-        else:
-            inv_freq_expanded = inv_freq[None, :]
-            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling
-
-        self.register_buffer("_cos_cached", cos, persistent=False)
-        self.register_buffer("_sin_cached", sin, persistent=False)
-
-    def forward(self, x, seq_len):
-        return (
-            self._cos_cached[:seq_len].to(dtype=x.dtype),
-            self._sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
-        self.kvcache_partition_size = kvcache_partition_len
-        super().__init__(
-            self_attn=self_attn,
-            use_attention_mask=use_attention_mask,
-            use_position_ids=use_position_ids,
-            kvcache_block_size=kvcache_block_size,
-        )
-
-    def get_attention(self):
-        return FlashAttentionOp(
-            self.num_heads,
-            self.head_dim,
-            self.num_key_value_heads,
-            self.kvcache_partition_size,
-            self.use_attention_mask,
-            self.use_position_ids,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        if cos is not None and sin is not None:
-            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1095,6 +896,17 @@ class FlashAttentionOp(AttentionOp):
         )
         self.kvcache_partition_size = kvcache_partition_len

+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if self.use_attention_mask and not self.use_position_ids:
+            attn_op_name = "paged_flash_attn_"
+        else:
+            attn_op_name = "paged_flash_causal_attn_"
+
+        attn_op_name += phase
+
+        return attn_op_name
+
     def forward(
         self,
         query_state,
@@ -1106,7 +918,7 @@ class FlashAttentionOp(AttentionOp):
         seq_position,
         scale,
         block_tables,
-        kvcache_block_size,
+        block_size,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1127,67 +939,32 @@ class FlashAttentionOp(AttentionOp):
             self.head_dim,
         )

-        if self.phase == "decode":
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-        else:
-            if self.use_attention_mask and not self.use_position_ids:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    mask=attn_mask,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                )
-            else:
-                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    q=query_state,
-                    k=key_state,
-                    v=value_state,
-                    kcache=past_key_state.unsqueeze(2),
-                    vcache=past_value_state.unsqueeze(2),
-                    seq=seq_position,
-                    scale=scale,
-                    block_table=block_tables,
-                    block_size=kvcache_block_size,
-                    partition=self.kvcache_partition_size,
-                    is_bidirectional=True if self.phase == "image_prefill" else False,
-                    mask=attn_mask if self.use_position_ids else None,
-                )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "seq": seq_position,
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+            "partition": self.kvcache_partition_size,
+        }
+
+        if self.use_attention_mask:
+            op_args["mask"] = attn_mask
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -1196,6 +973,14 @@ class FlashAttentionOp(AttentionOp):


 class SlidingWindowAttentionOp(AttentionOp):
+    def get_attn_op_name(self):
+        phase = "decode" if self.phase == "decode" else "prefill"
+        if not self.use_attention_mask:
+            raise NotImplementedError("Attention mask is needed for sliding window attention.")
+
+        attn_op_name = "paged_sliding_window_attn_" + phase
+        return attn_op_name
+
     def forward(
         self,
         query_state: torch.Tensor,
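Note: SlidingWindowAttentionOp keeps the same name-based dispatch but refuses to run without an attention mask, and its kernels take the sequence position as a (cache_seq_len, cache_offset) pair, which is why the forward below indexes seq_position[0] and seq_position[1]. A small sketch of the name resolution, mirroring the method above:

# Mirrors SlidingWindowAttentionOp.get_attn_op_name: one kernel family, decode vs. prefill.
def sliding_window_op_name(phase: str, use_attention_mask: bool = True) -> str:
    if not use_attention_mask:
        raise NotImplementedError("Attention mask is needed for sliding window attention.")
    return "paged_sliding_window_attn_" + ("decode" if phase == "decode" else "prefill")


print(sliding_window_op_name("image_prefill"))  # paged_sliding_window_attn_prefill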
@@ -1226,37 +1011,121 @@ class SlidingWindowAttentionOp(AttentionOp):
             self.head_dim,
         )

-        if self.phase == "decode":
-            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
-                q=query_state,
-                k=key_state,
-                v=value_state,
-                kcache=past_key_state.unsqueeze(2),
-                vcache=past_value_state.unsqueeze(2),
-                cache_seq_len=seq_position[0],
-                cache_offset=seq_position[1],
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-            )
-        else:
-            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
-                q=query_state,
-                k=key_state,
-                v=value_state,
-                kcache=past_key_state.unsqueeze(2),
-                vcache=past_value_state.unsqueeze(2),
-                cache_seq_len=seq_position[0],
-                cache_offset=seq_position[1],
-                scale=scale,
-                block_table=block_tables,
-                block_size=block_size,
-                is_bidirectional=True if self.phase == "image_prefill" else False,
-            )
-
-        # reshape for removing repeat_kv
+        op_args = {
+            "q": query_state,
+            "k": key_state,
+            "v": value_state,
+            "kcache": past_key_state.unsqueeze(2),
+            "vcache": past_value_state.unsqueeze(2),
+            "cache_seq_len": seq_position[0],
+            "cache_offset": seq_position[1],
+            "scale": scale,
+            "block_table": block_tables,
+            "block_size": block_size,
+        }
+
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+        attn_op_name = self.get_attn_op_name()
+        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+        if attn_op is None:
+            raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+        attn_output = attn_op(**op_args)
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

         return attn_output
+
+
+class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
+    ):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]
+
+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=x.dtype),
+            self._sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
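Note: rotate_half, apply_rotary_pos_emb, and the other RoPE helpers are only relocated to the end of the module in this release, not changed. A tiny numeric check of what they compute; the import path is taken from the file list above, and the availability of these module-level helpers is assumed from the diff:

# Tiny check of the relocated RoPE helpers: rotate_half swaps the two halves of the last
# dimension with a sign flip, and a zero rotation angle (cos=1, sin=0) is the identity.
import torch

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    apply_rotary_pos_emb,
    rotate_half,
)

q = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(rotate_half(q))  # tensor([[-3., -4.,  1.,  2.]])

cos, sin = torch.ones_like(q), torch.zeros_like(q)
q_embed, k_embed = apply_rotary_pos_emb(q, q, cos, sin)
print(torch.equal(q_embed, q))  # True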