optimum-rbln 0.1.13-py3-none-any.whl → 0.1.15-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -22,20 +22,23 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import math
25
- from typing import Dict, Optional, Tuple
25
+ from typing import List, Optional, Tuple
26
26
 
27
27
  import torch
28
28
  from torch import nn
29
- from transformers import PretrainedConfig
30
- from transformers.modeling_outputs import (
31
- BaseModelOutputWithPast,
32
- )
29
+ from transformers import PretrainedConfig, PreTrainedModel
30
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
33
31
 
34
32
  from ....utils import logging
35
- from ...cache_utils import RebelDynamicCache
36
33
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
37
34
 
38
35
 
36
+ if is_torch_greater_or_equal_than_2_4:
37
+ register_fake = torch.library.register_fake
38
+ else:
39
+ register_fake = torch.library.impl_abstract
40
+
41
+
39
42
  logger = logging.get_logger(__name__)
40
43
  """
41
44
  ##############################################################################
@@ -83,7 +86,7 @@ def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
83
86
  return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
84
87
 
85
88
 
86
- @torch.library.impl_abstract("rbln_custom_ops::flash_attn_decode")
89
+ @register_fake("rbln_custom_ops::flash_attn_decode")
87
90
  def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
88
91
  return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
89
92
 
@@ -129,7 +132,7 @@ def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition)
129
132
  return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
130
133
 
131
134
 
132
- @torch.library.impl_abstract("rbln_custom_ops::flash_attn_prefill")
135
+ @register_fake("rbln_custom_ops::flash_attn_prefill")
133
136
  def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
134
137
  return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
135
138
 
@@ -144,501 +147,599 @@ def rbln_cache_update_cpu(cache, value, batch, seq):
144
147
  return updated_cache
145
148
 
146
149
 
147
- @torch.library.impl_abstract("rbln_custom_ops::rbln_cache_update")
150
+ @register_fake("rbln_custom_ops::rbln_cache_update")
148
151
  def rbln_cache_update_abstract(cache, value, batch, seq):
149
152
  return torch.empty_like(cache)
150
153
 
151
154
 
152
- class DecoderOnlyAttention:
153
- def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
154
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
155
- key_state = key_state.unsqueeze(2)
156
- value_state = value_state.unsqueeze(2)
157
- attn_mask = attn_mask.unsqueeze(2)
155
+ class DecoderOnlyWrapper(nn.Module):
156
+ """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
158
157
 
159
- query_state = query_state.view(
160
- 1,
161
- self.num_key_value_heads,
162
- self.num_heads // self.num_key_value_heads,
163
- -1,
164
- self.head_dim,
165
- )
158
+ This wrapper is designed to:
159
+ 1. Convert Huggingface decoder models for RBLN compilation with static shapes
160
+ 2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
161
+ 3. Manage different attention implementations (standard and flash attention)
162
+ 4. Support both prefill and decode phases
166
163
 
167
- key_state, value_state = past_key_value.update(
168
- key_state, value_state, self.layer_idx, batch_idx, read_first_step=is_prefill
169
- )
164
+ Notes:
165
+ - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
166
+ - Wrapper should not contain neural network graph operations (including memory view handling)
170
167
 
171
- attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
172
- attn_weight += attn_mask
173
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_state.dtype)
174
- attn_output = torch.matmul(attn_weight, value_state)
168
+ Args:
169
+ causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
170
+ max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
171
+ use_rotary_emb (bool): Whether to use rotary position embeddings
172
+ kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
173
+ If provided, uses flash attention; if None, uses standard attention
174
+ """
175
175
 
176
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
177
- attn_output = attn_output.transpose(1, 2).contiguous()
178
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
176
+ def __init__(self, causal_lm: PreTrainedModel, max_seq_len, use_rotary_emb: bool, kvcache_partition_len=None):
177
+ super().__init__()
178
+ self.config = causal_lm.config
179
179
 
180
- return attn_output, key_state, value_state
180
+ if use_rotary_emb:
181
+ self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
182
+ else:
183
+ self.rotary_emb = None
181
184
 
182
- def forward(
183
- self,
184
- hidden_states: torch.Tensor,
185
- attention_mask: Optional[torch.Tensor] = None,
186
- position_ids: Optional[torch.LongTensor] = None,
187
- past_key_value: Optional[RebelDynamicCache] = None,
188
- batch_index: Optional[torch.Tensor] = None,
189
- output_attentions: bool = False,
190
- cos: Optional[torch.Tensor] = None,
191
- sin: Optional[torch.Tensor] = None,
192
- **kwargs,
193
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
194
- bsz, q_len, _ = hidden_states.size()
195
- query_states = self.q_proj(hidden_states)
196
- key_states = self.k_proj(hidden_states)
197
- value_states = self.v_proj(hidden_states)
185
+ if kvcache_partition_len is not None:
186
+ # WORKAROUND : for passing partition length as a value to the rbln compiler.
187
+ # What is actually used is the shape of this tensor.
188
+ self.attn_impl = "flash_attn"
189
+ logger.info(f"Using flash-attention. (partition length : {kvcache_partition_len})")
190
+ else:
191
+ self.attn_impl = "eager"
192
+ self.kvcache_partition_len = kvcache_partition_len
198
193
 
199
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
200
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
201
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
202
-
203
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
204
-
205
- # Decoder (bsz > 1)
206
- if bsz > 1:
207
- iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
208
- for b in range(bsz):
209
- attn_output, key_state, value_state = DecoderOnlyAttention._attn(
210
- self,
211
- query_states[b].unsqueeze(0),
212
- key_states[b].unsqueeze(0),
213
- value_states[b].unsqueeze(0),
214
- attention_mask[b].unsqueeze(0),
215
- past_key_value,
216
- batch_idx=b,
217
- is_prefill=False,
194
+ self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
195
+
196
+ self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
197
+ self._phase = "prefill"
198
+
199
+ def get_rotary_emb(self, max_seq_len):
200
+ return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
201
+
202
+ def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
203
+ new_layers = []
204
+ for layer in causal_lm.model.layers:
205
+ if self.attn_impl == "eager":
206
+ new_self_attn = DecoderOnlyAttention(layer.self_attn)
207
+ elif self.attn_impl == "flash_attn":
208
+ new_self_attn = DecoderOnlyFlashAttention(
209
+ layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
218
210
  )
211
+ else:
212
+ raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
213
+
214
+ new_layer = DecoderOnlyLayer(layer, new_self_attn)
215
+ new_layers.append(new_layer)
216
+ new_model = DecoderOnlyModel(causal_lm.model, new_layers)
217
+ new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
218
+ return new_causal_lm
219
219
 
220
- iterate_results["key_states"].append(key_state)
221
- iterate_results["value_states"].append(value_state)
222
- iterate_results["attn_output"].append(attn_output)
220
+ @property
221
+ def phase(self):
222
+ return self._phase
223
223
 
224
- key_states = torch.cat(iterate_results["key_states"], dim=0)
225
- value_states = torch.cat(iterate_results["value_states"], dim=0)
226
- attn_output = torch.cat(iterate_results["attn_output"], dim=0)
227
- # Prefill & Decoder (bsz == 1)
224
+ @phase.setter
225
+ def phase(self, phase: str):
226
+ self._phase = phase
227
+ self.causal_lm.phase = phase
228
+
229
+ def forward(
230
+ self,
231
+ input_ids_or_inputs_embeds,
232
+ attention_mask,
233
+ cache_position,
234
+ batch_position,
235
+ query_position,
236
+ *past_key_values,
237
+ ):
238
+ if input_ids_or_inputs_embeds.ndim == 2:
239
+ # It is input_ids
240
+ input_ids = input_ids_or_inputs_embeds
241
+ inputs_embeds = None
242
+ elif input_ids_or_inputs_embeds.ndim == 3:
243
+ # It is inputs_embeds
244
+ input_ids = None
245
+ inputs_embeds = input_ids_or_inputs_embeds
228
246
  else:
229
- attn_output, key_states, value_states = DecoderOnlyAttention._attn(
230
- self,
231
- query_states,
232
- key_states,
233
- value_states,
234
- attention_mask,
235
- past_key_value,
236
- batch_idx=batch_index,
237
- is_prefill=True,
247
+ raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
248
+
249
+ if len(past_key_values) != 2 * self.num_hidden_layers:
250
+ raise ValueError(
251
+ f"Different past_key_values to model's config. {len(past_key_values)} != {self.num_hidden_layers}"
238
252
  )
239
253
 
240
- attn_output = self.o_proj(attn_output)
254
+ seq_len = input_ids_or_inputs_embeds.shape[1]
255
+ if seq_len == 1:
256
+ self.phase = "decode"
257
+ else:
258
+ self.phase = "prefill"
259
+
260
+ # [key, value] * n_layer -> ( (key, value) ) * n_layer
261
+ # cache shape : batch, n_heads, 1, max_seq_len, head_dim
262
+ _past_key_values = []
263
+ for i in range(self.config.num_hidden_layers):
264
+ key_states = past_key_values[i * 2]
265
+ value_states = past_key_values[i * 2 + 1]
266
+ past_key_value = [key_states, value_states]
267
+ _past_key_values.append(past_key_value)
268
+ past_key_values = _past_key_values
269
+
270
+ logit, present_key_values = self.causal_lm(
271
+ input_ids=input_ids,
272
+ inputs_embeds=inputs_embeds,
273
+ attention_mask=attention_mask,
274
+ cache_position=cache_position,
275
+ batch_position=batch_position,
276
+ query_position=query_position,
277
+ past_key_values=past_key_values,
278
+ rotary_emb=self.rotary_emb,
279
+ )
280
+
281
+ # ((key, value)) * n_layer -> [key, value] * n_layer
282
+ _present_key_values = ()
283
+ for i in range(self.num_hidden_layers):
284
+ key_states = present_key_values[i][0]
285
+ value_states = present_key_values[i][1]
286
+ _present_key_values = _present_key_values + (key_states, value_states)
287
+ present_key_values = _present_key_values
241
288
 
242
- if not output_attentions:
243
- attn_weight = None
289
+ # batch_position + query_position is dummy output node to keep the number of outputs
290
+ return logit, present_key_values, batch_position + query_position
244
291
 
245
- return attn_output, attn_weight, key_states, value_states
246
292
 
293
+ class DecoderOnlyForCausalLM(nn.Module):
294
+ """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
247
295
 
248
- class DecoderOnlyFlashAttention:
249
- def forward(
250
- self,
251
- hidden_states: torch.Tensor,
252
- attention_mask: Optional[torch.Tensor] = None,
253
- position_ids: Optional[torch.LongTensor] = None,
254
- past_key_value: Optional[RebelDynamicCache] = None,
255
- batch_index: Optional[torch.Tensor] = None,
256
- output_attentions: bool = False,
257
- cos: Optional[torch.Tensor] = None,
258
- sin: Optional[torch.Tensor] = None,
259
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
260
- kvcache_partition_size: Optional[torch.Tensor] = None,
261
- **kwargs,
262
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
263
- bsz, q_len, _ = hidden_states.size()
264
- query_states = self.q_proj(hidden_states)
265
- key_states = self.k_proj(hidden_states)
266
- value_states = self.v_proj(hidden_states)
296
+ This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
297
+ 1. Managing model phases (prefill/decode) throughout the computation graph
298
+ 2. Handling output shape alignments for static compilation
299
+ 3. Coordinating between the original model and RBLN-optimized components
267
300
 
268
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
269
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
270
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
271
-
272
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
273
-
274
- # Decoder (bsz > 1)
275
- if bsz > 1:
276
- all_key_states = []
277
- all_value_states = []
278
- all_attn_output = []
279
-
280
- for b in range(bsz):
281
- query_state = query_states[b].unsqueeze(0)
282
- attn_mask = attention_mask[b].unsqueeze(0)
283
- key_state = key_states[b].unsqueeze(0)
284
- value_state = value_states[b].unsqueeze(0)
285
-
286
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
287
- key_state = key_state.unsqueeze(2)
288
- value_state = value_state.unsqueeze(2)
289
- attn_mask = attn_mask.unsqueeze(2)
290
-
291
- query_state = query_state.view(
292
- 1,
293
- self.num_key_value_heads,
294
- self.num_heads // self.num_key_value_heads,
295
- q_len,
296
- self.head_dim,
297
- )
301
+ The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
302
+ focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.
298
303
 
299
- # RBLN custom flash attention(decode), dummy batch index
300
- sidx = cache_pos_for_partitions[b][0]
301
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
302
- query_state,
303
- key_state,
304
- value_state,
305
- attn_mask,
306
- past_key_value.key_cache[self.layer_idx].unsqueeze(2),
307
- past_key_value.value_cache[self.layer_idx].unsqueeze(2),
308
- sidx,
309
- kvcache_partition_size,
310
- )
304
+ Args:
305
+ causal_lm (PreTrainedModel): Original Huggingface causal language model
306
+ model (DecoderOnlyModel): RBLN-optimized model instance
311
307
 
312
- # reshape for removing repeat_kv
313
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
314
- attn_output = attn_output.transpose(1, 2).contiguous()
315
- attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
308
+ Attributes:
309
+ config: Configuration from the original causal language model
310
+ _original_mod: Reference to the original model for components like lm_head
311
+ model: RBLN-optimized decoder model instance
312
+ _phase: Current processing phase ("prefill" or "decode")
313
+ """
316
314
 
317
- all_key_states.append(key_state)
318
- all_value_states.append(value_state)
319
- all_attn_output.append(attn_output)
315
+ def __init__(self, causal_lm: PreTrainedModel, model):
316
+ super().__init__()
317
+ self.config = causal_lm.config
318
+ self._original_mod = causal_lm
319
+ self.model = model
320
+ self._phase = "prefill"
320
321
 
321
- key_states = torch.cat(all_key_states, dim=0)
322
- value_states = torch.cat(all_value_states, dim=0)
323
- attn_output = torch.cat(all_attn_output, dim=0)
322
+ @property
323
+ def phase(self):
324
+ return self._phase
324
325
 
325
- else:
326
- # reshape for removing repeat_kv
327
- key_states = key_states.unsqueeze(2)
328
- value_states = value_states.unsqueeze(2)
329
- attention_mask = attention_mask.unsqueeze(2)
330
- query_states = query_states.view(
331
- 1,
332
- self.num_key_value_heads,
333
- self.num_heads // self.num_key_value_heads,
334
- q_len,
335
- self.head_dim,
336
- )
326
+ @phase.setter
327
+ def phase(self, phase: str):
328
+ self._phase = phase
329
+ self.model.phase = phase
337
330
 
338
- assert batch_index.dim() == 0
339
- assert not output_attentions
340
- bidx = batch_index
341
- sidx = cache_pos_for_partitions[0][0]
342
- attn_output, key_states, value_states = torch.ops.rbln_custom_ops.flash_attn_prefill(
343
- query_states,
344
- key_states,
345
- value_states,
346
- attention_mask,
347
- past_key_value.key_cache[self.layer_idx].unsqueeze(2),
348
- past_key_value.value_cache[self.layer_idx].unsqueeze(2),
349
- bidx,
350
- sidx,
351
- kvcache_partition_size,
352
- )
331
+ def forward(
332
+ self,
333
+ input_ids: torch.Tensor = None,
334
+ inputs_embeds: torch.Tensor = None,
335
+ attention_mask: torch.Tensor = None,
336
+ cache_position: torch.Tensor = None,
337
+ batch_position: torch.Tensor = None,
338
+ query_position: torch.Tensor = None,
339
+ past_key_values: Tuple[Tuple[torch.Tensor]] = None,
340
+ rotary_emb: nn.Module = None,
341
+ ):
342
+ # outputs
343
+ hidden_states, present_key_values = self.model(
344
+ input_ids=input_ids,
345
+ inputs_embeds=inputs_embeds,
346
+ attention_mask=attention_mask,
347
+ cache_position=cache_position,
348
+ batch_position=batch_position,
349
+ past_key_values=past_key_values,
350
+ rotary_emb=rotary_emb,
351
+ )
353
352
 
354
- # reshape for removing repeat_kv
355
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
356
- attn_output = attn_output.transpose(1, 2).contiguous()
357
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
353
+ if self.phase == "prefill":
354
+ hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
358
355
 
359
- attn_output = self.o_proj(attn_output)
356
+ logits = self._original_mod.lm_head(hidden_states)
357
+ output = (logits, present_key_values)
358
+ return output
360
359
 
361
- if not output_attentions:
362
- attn_weight = None
363
360
 
364
- return attn_output, attn_weight, key_states, value_states
361
+ class DecoderOnlyModel(nn.Module):
362
+ """A modified decoder-only model implementation optimized for RBLN compilation.
365
363
 
364
+ Args:
365
+ model: Original Huggingface model to adapt
366
+ layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
366
367
 
367
- DECODERONLY_ATTENTION_CLASSES = {
368
- "eager": DecoderOnlyAttention,
369
- "flash_attn_rbln": DecoderOnlyFlashAttention,
370
- # "sdpa": DecoderOnlySdpaAttention,
371
- }
368
+ Attributes:
369
+ _original_mod: Reference to original Huggingface model
370
+ layers: ModuleList of RBLN-optimized transformer layers
371
+ _phase: Current processing phase ("prefill" or "decode")
372
+ """
372
373
 
374
+ mask_fmin = torch.finfo(torch.float16).min
373
375
 
374
- class DecoderOnlyWrapper(torch.nn.Module):
375
- def __init__(self, model, max_seq_len, kvcache_partition_len=None):
376
+ def __init__(self, model, layers: List["DecoderOnlyLayer"]):
376
377
  super().__init__()
377
- self.config = model.config
378
- self.model = model.model
379
- self.lm_head = model.lm_head
380
- self.max_seq_len = max_seq_len
381
- self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
378
+ self._original_mod = model
379
+ self.layers = nn.ModuleList(layers)
380
+ self._phase = "prefill"
382
381
 
383
- if kvcache_partition_len is not None:
384
- # WORKAROUND : for passing partition length as a value to the rbln compiler.
385
- # What is actually used is the shape of this tensor.
386
- self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
387
- self.attn_implementation = "flash_attn_rbln"
388
- logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
389
- else:
390
- self.kvcache_partition_size = None
391
- self.attn_implementation = "eager"
382
+ @property
383
+ def phase(self):
384
+ return self._phase
385
+
386
+ @phase.setter
387
+ def phase(self, phase: str):
388
+ self._phase = phase
389
+ for layer in self.layers:
390
+ layer.phase = phase
392
391
 
393
- def get_forward_dict(self):
394
- forward_dict = {
395
- "wrapper": DecoderOnlyModel.forward,
396
- "model": DecoderOnlyDecoderLayer.forward,
397
- "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
398
- }
399
- return forward_dict
392
+ @property
393
+ def hidden_multiplier(self):
394
+ return 1
395
+
396
+ def get_last_layernorm(self) -> nn.LayerNorm:
397
+ return self._original_mod.norm
398
+
399
+ def get_embedding(self) -> nn.Embedding:
400
+ return self._original_mod.embed_tokens
401
+
402
+ def get_pos_embedding(self) -> nn.Embedding:
403
+ raise NotImplementedError(
404
+ "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
405
+ )
400
406
 
401
407
  def forward(
402
408
  self,
403
- input_ids_or_inputs_embeds,
404
- attention_mask,
405
- cache_position,
406
- batch_position,
407
- query_idx,
408
- *past_key_values,
409
+ input_ids: torch.Tensor = None,
410
+ inputs_embeds: torch.Tensor = None,
411
+ attention_mask: torch.Tensor = None,
412
+ cache_position: torch.Tensor = None,
413
+ batch_position: torch.Tensor = None,
414
+ past_key_values: Tuple[Tuple[torch.Tensor]] = None,
415
+ rotary_emb: nn.Module = None,
409
416
  ):
410
- if input_ids_or_inputs_embeds.ndim == 2:
411
- # input_ids
412
- input_ids = input_ids_or_inputs_embeds
413
- inputs_embeds = None
414
- elif input_ids_or_inputs_embeds.ndim == 3:
415
- # inputs_embeds
416
- input_ids = None
417
- inputs_embeds = input_ids_or_inputs_embeds
418
- else:
419
- raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
417
+ # retrieve input_ids and inputs_embeds
418
+ if (input_ids is None) ^ (inputs_embeds is not None):
419
+ raise ValueError(
420
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
421
+ )
420
422
 
421
- # Formatting list of past_kv to DynamicCache class.
422
- past_key_values = RebelDynamicCache.from_input_format(
423
- cache_position,
424
- self.config.num_hidden_layers,
425
- *past_key_values,
426
- )
423
+ # embed positions
424
+ if inputs_embeds is None:
425
+ inputs_embeds = self.get_embedding()(input_ids)
427
426
 
428
- batch_size = input_ids_or_inputs_embeds.size()[0]
429
- seq_len = input_ids_or_inputs_embeds.size()[1]
430
-
431
- if self.attn_implementation == "eager":
432
- cache_pos_for_partitions = None
433
- elif self.attn_implementation == "flash_attn_rbln":
434
- p_len = self.kvcache_partition_size.size()[0]
435
- num_partition = self.max_seq_len // p_len
436
- if self.max_seq_len % p_len > 0:
437
- raise ValueError(
438
- f"The partition length({p_len}) must be exactly divisible by the max_seq_len({self.max_seq_len})."
439
- )
440
- cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
427
+ hidden_states = inputs_embeds * self.hidden_multiplier
428
+ attention_mask = (1 - attention_mask) * self.mask_fmin
441
429
 
442
- if batch_size > 1: # decode
443
- for b_idx in range(batch_size):
444
- decoding_step = cache_position[b_idx]
445
- cache_pos = decoding_step
446
- for p_idx in range(num_partition):
447
- input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
448
- input_1 = torch.tensor(p_len, dtype=torch.int32)
449
- min = torch.minimum(input_0, input_1)
450
- cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
451
- cache_pos_for_partitions[b_idx][p_idx] = cache_pos_for_partition
452
- else: # prefill
453
- cache_pos = cache_position[0][0]
454
- for p_idx in range(num_partition):
455
- input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
456
- input_1 = torch.tensor(p_len, dtype=torch.int32)
457
- min = torch.minimum(input_0, input_1)
458
- cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
459
- cache_pos_for_partitions[0][p_idx] = cache_pos_for_partition
430
+ # get cos,sin vector if needed
431
+ if rotary_emb is not None:
432
+ cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1]) # dtype carrier, max_seq_len
433
+ cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
460
434
  else:
461
- raise NotImplementedError(f"Unknown attn_implementation: {self.attn_implementation}")
435
+ batch_size = inputs_embeds.shape[0]
436
+ if cache_position.shape[0] > 1:
437
+ position_embeds = []
438
+ for b_idx in range(batch_size):
439
+ position_embed = self.get_pos_embedding()(cache_position[b_idx])
440
+ position_embeds.append(position_embed)
441
+
442
+ position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
443
+ else:
444
+ position_embeds = self.get_pos_embedding()(cache_position)
445
+ hidden_states = hidden_states + position_embeds
446
+ cos, sin = None, None
447
+
448
+ # (batch, seq_len) -> (batch,)
449
+ current_steps = cache_position[:, 0]
450
+
451
+ present_key_values = past_key_values
452
+ for layer in self.layers:
453
+ hidden_states, present_key_values = layer(
454
+ hidden_states=hidden_states,
455
+ attention_mask=attention_mask,
456
+ current_steps=current_steps,
457
+ batch_position=batch_position,
458
+ past_key_values=present_key_values,
459
+ cos=cos,
460
+ sin=sin,
461
+ )
462
462
 
463
- forward_dict = self.get_forward_dict()
464
- outputs = forward_dict["wrapper"](
465
- self.model,
466
- input_ids=input_ids,
467
- inputs_embeds=inputs_embeds,
468
- attention_mask=attention_mask,
469
- position_ids=cache_position,
470
- past_key_values=past_key_values,
471
- batch_ids=batch_position,
472
- rotary_pos_emb=self.rotary_emb,
473
- cache_pos_for_partitions=cache_pos_for_partitions,
474
- kvcache_partition_size=self.kvcache_partition_size,
475
- forward_dict=forward_dict,
476
- )
463
+ hidden_states = self.get_last_layernorm()(hidden_states)
464
+ return hidden_states, present_key_values
465
+
466
+
467
+ class DecoderOnlyLayer(nn.Module):
468
+ """A single transformer layer adapted for RBLN compilation with static shapes.
469
+
470
+ This layer implements a modified transformer block that includes:
471
+ 1. Self-attention mechanism (either standard or flash attention)
472
+ 2. Feed-forward network (FFN)
473
+ 3. Layer normalization
474
+ 4. Residual connections
477
475
 
478
- hidden_states = outputs[0]
479
- if seq_len != 1:
480
- hidden_states = hidden_states[:, query_idx.to(torch.int).unsqueeze(0)]
476
+ The layer is specifically designed to:
477
+ - Support compilation to RBLN custom ops
478
+ - Maintain static tensor shapes throughout computations
479
+ - Handle both prefill and decode phases efficiently
480
+ - Manage attention state transitions properly
481
481
 
482
- logits = self.lm_head(hidden_states)
482
+ Args:
483
+ layer: Original transformer layer module to wrap
484
+ self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
483
485
 
484
- output = (logits,) + outputs[1:]
486
+ Attributes:
487
+ _original_mod: Reference to original layer for accessing components
488
+ self_attn: Modified attention mechanism mapped to RBLN ops at compile time
489
+ phase: Current operation phase ("prefill" or "decode")
490
+ """
491
+
492
+ def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
493
+ super().__init__()
494
+ self._original_mod = layer
495
+ self.self_attn = self_attn
496
+ self._phase = "prefill"
497
+
498
+ @property
499
+ def phase(self):
500
+ return self._phase
485
501
 
486
- return output, batch_position + query_idx
502
+ @phase.setter
503
+ def phase(self, phase: str):
504
+ self._phase = phase
505
+ self.self_attn.phase = phase
487
506
 
507
+ def get_pre_attention_layernorm(self) -> nn.LayerNorm:
508
+ return self._original_mod.input_layernorm
509
+
510
+ def get_post_attention_layernorm(self) -> nn.LayerNorm:
511
+ return self._original_mod.post_attention_layernorm
488
512
 
489
- class DecoderOnlyDecoderLayer:
490
513
  def forward(
491
514
  self,
492
515
  hidden_states: torch.Tensor,
493
- layer_idx: int,
494
- attention_mask: Optional[torch.Tensor] = None,
495
- position_ids: Optional[torch.LongTensor] = None,
496
- past_key_value: Optional[RebelDynamicCache] = None,
497
- output_attentions: Optional[bool] = None,
498
- use_cache: Optional[bool] = None,
499
- batch_ids: Optional[torch.Tensor] = None,
516
+ attention_mask: torch.Tensor,
517
+ current_steps: torch.LongTensor,
518
+ batch_position: torch.Tensor,
519
+ past_key_values: Tuple[Tuple[torch.Tensor]],
500
520
  cos: Optional[torch.Tensor] = None,
501
521
  sin: Optional[torch.Tensor] = None,
502
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
503
- kvcache_partition_size: Optional[torch.Tensor] = None,
504
- forward_dict: Optional[Dict[str, classmethod]] = None,
505
- **kwargs,
506
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
522
+ ):
507
523
  residual = hidden_states
508
524
 
509
- hidden_states = self.input_layernorm(hidden_states)
525
+ hidden_states = self.get_pre_attention_layernorm()(hidden_states)
510
526
 
511
- hidden_states, self_attn_weight, k, v = forward_dict["decoder_layer"](
512
- self.self_attn,
527
+ hidden_states, present_key_values = self.self_attn(
513
528
  hidden_states=hidden_states,
514
529
  attention_mask=attention_mask,
515
- position_ids=position_ids,
516
- past_key_value=past_key_value,
517
- output_attentions=output_attentions,
518
- batch_index=batch_ids,
519
- use_cache=use_cache,
530
+ current_steps=current_steps,
531
+ batch_position=batch_position,
532
+ past_key_values=past_key_values,
520
533
  cos=cos,
521
534
  sin=sin,
522
- cache_pos_for_partitions=cache_pos_for_partitions,
523
- kvcache_partition_size=kvcache_partition_size,
524
- **kwargs,
525
535
  )
526
- past_key_value.assign(k, v, layer_idx)
527
-
528
536
  hidden_states = residual + hidden_states
529
537
 
530
538
  # Fully Connected
531
539
  residual = hidden_states
532
- hidden_states = self.post_attention_layernorm(hidden_states)
533
- hidden_states = self.mlp(hidden_states)
540
+ hidden_states = self.get_post_attention_layernorm()(hidden_states)
541
+ hidden_states = self._original_mod.mlp(hidden_states)
534
542
  hidden_states = residual + hidden_states
535
543
 
536
- outputs = (hidden_states,)
544
+ return hidden_states, present_key_values
537
545
 
538
- if output_attentions:
539
- outputs += (self_attn_weight,)
540
546
 
541
- if use_cache:
542
- outputs += (past_key_value,)
547
+ class DecoderOnlyAttention(nn.Module):
548
+ """Attention implementation for decoder-only models optimized for RBLN compilation.
543
549
 
544
- return outputs
550
+ This class implements a modified version of the standard attention mechanism that:
551
+ 1. Supports static shape requirements for RBLN compilation
552
+ 2. Handles explicit batch and position management
545
553
 
554
+ Args:
555
+ self_attn: Original attention module from the base model
556
+ """
546
557
 
547
- class DecoderOnlyModel:
548
- def forward(
558
+ def __init__(self, self_attn):
559
+ super().__init__()
560
+ self._original_mod = self_attn
561
+ self.layer_idx = self_attn.layer_idx
562
+ self.num_heads = self._original_mod.num_heads
563
+ self.head_dim = self._original_mod.head_dim
564
+ self.phase = "prefill"
565
+ self.__post_init__()
566
+
567
+ def __post_init__(self):
568
+ self.q_proj = self._original_mod.q_proj
569
+ self.k_proj = self._original_mod.k_proj
570
+ self.v_proj = self._original_mod.v_proj
571
+ self.o_proj = self._original_mod.o_proj
572
+ self.num_key_value_heads = self._original_mod.num_key_value_heads
573
+
574
+ def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
575
+ """Projects input hidden states into query, key, and value representations.
576
+
577
+ Args:
578
+ hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
579
+
580
+ Returns:
581
+ Tuple of (query_states, key_states, value_states)
582
+ """
583
+ query_states = self.q_proj(hidden_states)
584
+ key_states = self.k_proj(hidden_states)
585
+ value_states = self.v_proj(hidden_states)
586
+ return query_states, key_states, value_states
587
+
588
+ def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
589
+ return apply_rotary_pos_emb(query_states, key_states, cos, sin)
590
+
591
+ def rbln_attention(
549
592
  self,
550
- input_ids: torch.LongTensor = None,
551
- attention_mask: Optional[torch.Tensor] = None,
552
- position_ids: Optional[torch.LongTensor] = None,
553
- past_key_values: Optional[RebelDynamicCache] = None,
554
- batch_ids: Optional[torch.Tensor] = None,
555
- inputs_embeds: Optional[torch.FloatTensor] = None,
556
- use_cache: Optional[bool] = True,
557
- output_attentions: Optional[bool] = False,
558
- output_hidden_states: Optional[bool] = False,
559
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
560
- kvcache_partition_size: Optional[torch.Tensor] = None,
561
- forward_dict: Optional[Dict[str, classmethod]] = None,
562
- rotary_pos_emb=None,
563
- ) -> BaseModelOutputWithPast:
564
- # retrieve input_ids and inputs_embeds
565
- if (input_ids is None) ^ (inputs_embeds is not None):
566
- raise ValueError(
567
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
568
- )
593
+ query_state,
594
+ key_state,
595
+ value_state,
596
+ attn_mask,
597
+ batch_idx,
598
+ past_key_state,
599
+ past_value_state,
600
+ current_step,
601
+ # below are designed for Midm, GPT which requires to support scaling for attention weights
602
+ # TODO(jongho): Merge and manage scales generally
603
+ layer_idx=None,
604
+ scale_attn_weights: bool = None,
605
+ scale_attn_by_inverse_layer_idx: bool = None,
606
+ scale_qk_by_inverse_layer_idx: bool = None,
607
+ ):
608
+ """Compute attention with static shapes and explicit cache management.
609
+
610
+ Args:
611
+ query_state: Query tensor [1, num_heads, 1, head_dim]
612
+ key_state: Key tensor [1, num_heads, seq_len, head_dim]
613
+ value_state: Value tensor [1, num_heads, seq_len, head_dim]
614
+ attn_mask: Attention mask tensor
615
+ batch_idx: Batch index for cache lookup
616
+ past_key_state: Previous key cache states
617
+ past_value_state: Previous value cache states
618
+ current_step: Current position in sequence
619
+
620
+ Returns:
621
+ Tuple of (attention_output, key_state, value_state)
622
+ """
623
+ # Implementation details.
624
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
625
+ key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
626
+ value_state = value_state.unsqueeze(2)
627
+ attn_mask = attn_mask.unsqueeze(2)
569
628
 
570
- # embed positions
571
- if inputs_embeds is None:
572
- inputs_embeds = self.embed_tokens(input_ids)
573
-
574
- hidden_states = inputs_embeds
575
- attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
576
-
577
- # get cos,sin vector
578
- cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
579
- cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
580
-
581
- # decoder layers
582
- all_hidden_states = () if output_hidden_states else None
583
- all_self_attns = () if output_attentions else None
584
-
585
- for layer_idx, decoder_layer in enumerate(self.layers):
586
- if output_hidden_states:
587
- all_hidden_states += (hidden_states,)
588
- layer_outputs = forward_dict["model"](
589
- decoder_layer,
590
- hidden_states,
591
- layer_idx,
592
- attention_mask=attention_mask,
593
- position_ids=position_ids,
594
- past_key_value=past_key_values,
595
- output_attentions=output_attentions,
596
- use_cache=use_cache,
597
- batch_ids=batch_ids,
598
- cos=cos,
599
- sin=sin,
600
- cache_pos_for_partitions=cache_pos_for_partitions,
601
- kvcache_partition_size=kvcache_partition_size,
602
- forward_dict=forward_dict,
603
- )
629
+ query_state = query_state.view(
630
+ 1,
631
+ self.num_key_value_heads,
632
+ self.num_heads // self.num_key_value_heads,
633
+ -1, # seq len
634
+ self.head_dim,
635
+ ) #
636
+
637
+ kend = current_step + key_state.shape[-2]
638
+ vend = current_step + value_state.shape[-2]
639
+
640
+ key_state = (
641
+ past_key_state[batch_idx]
642
+ .unsqueeze(0)
643
+ .unsqueeze(2)
644
+ .slice_scatter(key_state, dim=-2, start=current_step, end=kend)
645
+ )
646
+ value_state = (
647
+ past_value_state[batch_idx]
648
+ .unsqueeze(0)
649
+ .unsqueeze(2)
650
+ .slice_scatter(value_state, dim=-2, start=current_step, end=vend)
651
+ )
652
+
653
+ attn_weight = torch.matmul(query_state, key_state.transpose(3, 4))
654
+ attn_weight = attn_weight / math.sqrt(self.head_dim)
655
+
656
+ if layer_idx is not None and (scale_attn_by_inverse_layer_idx or scale_qk_by_inverse_layer_idx):
657
+ attn_weight = attn_weight / float(layer_idx + 1)
658
+
659
+ attn_weight += attn_mask
604
660
 
605
- hidden_states = layer_outputs[0]
661
+ if layer_idx is not None and scale_qk_by_inverse_layer_idx:
662
+ attn_weight = attn_weight * float(layer_idx + 1)
606
663
 
607
- updated_cache = layer_outputs[2 if output_attentions else 1]
664
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1)
608
665
 
609
- if output_attentions:
610
- all_self_attns += (layer_outputs[1],)
666
+ attn_output = torch.matmul(attn_weight, value_state)
611
667
 
612
- hidden_states = self.norm(hidden_states)
668
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
669
+ attn_output = attn_output.transpose(1, 2).contiguous()
670
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
613
671
 
614
- # add hidden states from the last decoder layer
615
- if output_hidden_states:
616
- all_hidden_states += (hidden_states,)
672
+ return attn_output, key_state, value_state
617
673
 
618
- # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
619
- next_cache = updated_cache.to_legacy_cache()
674
+ def forward(
675
+ self,
676
+ hidden_states: torch.Tensor,
677
+ attention_mask: torch.Tensor,
678
+ current_steps: torch.LongTensor,
679
+ batch_position: torch.Tensor,
680
+ past_key_values: Tuple[Tuple[torch.Tensor]],
681
+ cos: Optional[torch.Tensor] = None, # (batch, 1, prefill_size, head_dim)
682
+ sin: Optional[torch.Tensor] = None,
683
+ ):
684
+ batch_size, query_length, _ = hidden_states.size()
620
685
 
621
- return BaseModelOutputWithPast(
622
- last_hidden_state=hidden_states,
623
- past_key_values=next_cache,
624
- hidden_states=all_hidden_states,
625
- attentions=all_self_attns,
686
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
687
+
688
+ query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
689
+ key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
690
+ value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
691
+ 1, 2
626
692
  )
693
+ # b, num_head, query, head_dim
694
+
695
+ if cos is not None and sin is not None:
696
+ query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
697
+
698
+ if batch_size > 1 and self.phase == "prefill":
699
+ raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
700
+
701
+ _key_states = []
702
+ _value_states = []
703
+ _attn_outputs = []
704
+ for b in range(batch_size):
705
+ current_step = current_steps[b]
706
+ attn_output, key_state, value_state = self.rbln_attention(
707
+ query_states[b].unsqueeze(0),
708
+ key_states[b].unsqueeze(0),
709
+ value_states[b].unsqueeze(0),
710
+ attention_mask[b].unsqueeze(0)
711
+ if self.phase == "decode"
712
+ else attention_mask, # TODO(jongho): fix when msoftmax is supported
713
+ past_key_state=past_key_values[self.layer_idx][0],
714
+ past_value_state=past_key_values[self.layer_idx][1],
715
+ batch_idx=b if self.phase == "decode" else batch_position,
716
+ current_step=current_step,
717
+ )
718
+ _key_states.append(key_state)
719
+ _value_states.append(value_state)
720
+ _attn_outputs.append(attn_output)
721
+ key_states = torch.cat(_key_states, dim=0)
722
+ value_states = torch.cat(_value_states, dim=0)
723
+ attn_outputs = torch.cat(_attn_outputs, dim=0)
627
724
 
725
+ attn_outputs = self.o_proj(attn_outputs)
726
+ past_key_values[self.layer_idx] = key_states, value_states
727
+ return attn_outputs, past_key_values
628
728
 
629
- def slice_and_unsqueeze_cos_sin(cos, sin, position_ids, unsqueeze_dim=1):
630
- """Slice cos[position_ids], sin[position_ids] vector for the query."""
631
- if position_ids.shape[0] > 1:
729
+
730
+ def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
731
+ """Slice cos[cache_position], sin[cache_position] vector for the query."""
732
+ if cache_position.shape[0] > 1:
632
733
  cos_all = []
633
734
  sin_all = []
634
- for i in range(position_ids.shape[0]):
635
- cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
636
- sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
735
+ for i in range(cache_position.shape[0]):
736
+ cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
737
+ sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
637
738
  cos = torch.cat(cos_all, dim=0)
638
739
  sin = torch.cat(sin_all, dim=0)
639
740
  else:
640
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
641
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
741
+ cos = cos[cache_position].unsqueeze(unsqueeze_dim)
742
+ sin = sin[cache_position].unsqueeze(unsqueeze_dim)
642
743
 
643
744
  return cos, sin
644
745
 
@@ -658,6 +759,26 @@ def apply_rotary_pos_emb(q, k, cos, sin):
658
759
  return q_embed, k_embed
659
760
 
660
761
 
762
+ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
763
+ # Partial rotary embedding
764
+ query_rot, query_pass = (
765
+ query_states[..., :ndim],
766
+ query_states[..., ndim:],
767
+ )
768
+ key_rot, key_pass = (
769
+ key_states[..., :ndim],
770
+ key_states[..., ndim:],
771
+ )
772
+
773
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
774
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
775
+
776
+ # [batch_size, seq_length, num_heads, head_dim]
777
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
778
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
779
+ return query_states, key_states
780
+
781
+
661
782
  class RotaryEmbedding(nn.Module):
662
783
  """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
663
784
 
@@ -674,14 +795,14 @@ class RotaryEmbedding(nn.Module):
674
795
  rope_type = "default"
675
796
 
676
797
  inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
677
- position_ids = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
678
- position_ids_expanded = position_ids[:, None]
798
+ cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
799
+ cache_position_expanded = cache_position[:, None]
679
800
 
680
801
  if rope_type == "dynamic":
681
- freqs = position_ids_expanded.float() * inv_freq.float()
802
+ freqs = cache_position_expanded.float() * inv_freq.float()
682
803
  else:
683
804
  inv_freq_expanded = inv_freq[None, :]
684
- freqs = position_ids_expanded.float() @ inv_freq_expanded.float()
805
+ freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
685
806
 
686
807
  emb = torch.cat((freqs, freqs), dim=-1)
687
808
 
@@ -696,3 +817,142 @@ class RotaryEmbedding(nn.Module):
696
817
  self._cos_cached[:seq_len].to(dtype=x.dtype),
697
818
  self._sin_cached[:seq_len].to(dtype=x.dtype),
698
819
  )
820
+
821
+
822
+ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
823
+ def __init__(self, self_attn, kvcache_partition_len):
824
+ super().__init__(self_attn=self_attn)
825
+ self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
826
+
827
+ def get_cache_pos_for_partitions(self, current_steps, batch_size, max_seq_len):
828
+ partition_len = self.kvcache_partition_size.size()[0]
829
+ num_partition = max_seq_len // partition_len
830
+ cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
831
+ if self.phase == "decode":
832
+ for b_idx in range(batch_size):
833
+ cache_pos = current_steps[b_idx]
834
+ for p_idx in range(num_partition):
835
+ cache_pos_for_partitions[b_idx][p_idx] = torch.clamp(
836
+ cache_pos - partition_len * p_idx, 0, partition_len
837
+ )
838
+ else: # prefill
839
+ cache_pos = current_steps[0]
840
+ for p_idx in range(num_partition):
841
+ cache_pos_for_partitions[0][p_idx] = torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len)
842
+
843
+ return cache_pos_for_partitions
844
+
845
+ def rbln_flash_attention(
846
+ self,
847
+ query_state,
848
+ key_state,
849
+ value_state,
850
+ attn_mask,
851
+ batch_idx,
852
+ past_key_state,
853
+ past_value_state,
854
+ cache_pos_for_partitions,
855
+ ):
856
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
857
+ key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
858
+ value_state = value_state.unsqueeze(2)
859
+ attn_mask = attn_mask.unsqueeze(2)
860
+
861
+ query_state = query_state.view(
862
+ 1,
863
+ self.num_key_value_heads,
864
+ self.num_heads // self.num_key_value_heads,
865
+ -1, # seq len
866
+ self.head_dim,
867
+ )
868
+
869
+ # RBLN custom flash attention(decode), dummy batch index
870
+ if self.phase == "decode":
871
+ sidx = cache_pos_for_partitions[batch_idx][0]
872
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
873
+ query_state,
874
+ key_state,
875
+ value_state,
876
+ attn_mask,
877
+ past_key_state.unsqueeze(2),
878
+ past_value_state.unsqueeze(2),
879
+ sidx,
880
+ self.kvcache_partition_size,
881
+ )
882
+ else:
883
+ sidx = cache_pos_for_partitions[0][0]
884
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
885
+ query_state,
886
+ key_state,
887
+ value_state,
888
+ attn_mask,
889
+ past_key_state.unsqueeze(2),
890
+ past_value_state.unsqueeze(2),
891
+ batch_idx,
892
+ sidx,
893
+ self.kvcache_partition_size,
894
+ )
895
+
896
+ # reshape for removing repeat_kv
897
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
898
+ attn_output = attn_output.transpose(1, 2).contiguous()
899
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
900
+
901
+ return attn_output, key_state, value_state
902
+
903
+ def forward(
904
+ self,
905
+ hidden_states: torch.Tensor,
906
+ attention_mask: torch.Tensor,
907
+ current_steps: torch.LongTensor,
908
+ batch_position: torch.Tensor,
909
+ past_key_values: Tuple[Tuple[torch.Tensor]],
910
+ cos: Optional[torch.Tensor] = None,
911
+ sin: Optional[torch.Tensor] = None,
912
+ ):
913
+ batch_size, query_length, _ = hidden_states.size()
914
+
915
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
916
+
917
+ query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
918
+ key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
919
+ value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
920
+ 1, 2
921
+ )
922
+ # b, num_head, query, head_dim
923
+
924
+ max_seq_len = past_key_values[self.layer_idx][0].shape[-2]
925
+
926
+ if cos is not None and sin is not None:
927
+ query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
928
+
929
+ cache_pos_for_partitions = self.get_cache_pos_for_partitions(
930
+ current_steps, batch_size=batch_size, max_seq_len=max_seq_len
931
+ ) # batch_size, num_partitions
932
+
933
+ _key_states = []
934
+ _value_states = []
935
+ _attn_outputs = []
936
+ for b in range(batch_size):
937
+ attn_output, key_state, value_state = self.rbln_flash_attention(
938
+ query_states[b].unsqueeze(0),
939
+ key_states[b].unsqueeze(0),
940
+ value_states[b].unsqueeze(0),
941
+ attention_mask[b].unsqueeze(0)
942
+ if self.phase == "decode"
943
+ else attention_mask, # TODO(jongho): fix when msoftmax is supported
944
+ past_key_state=past_key_values[self.layer_idx][0],
945
+ past_value_state=past_key_values[self.layer_idx][1],
946
+ batch_idx=b if self.phase == "decode" else batch_position,
947
+ cache_pos_for_partitions=cache_pos_for_partitions,
948
+ )
949
+ _key_states.append(key_state)
950
+ _value_states.append(value_state)
951
+ _attn_outputs.append(attn_output)
952
+ key_states = torch.cat(_key_states, dim=0)
953
+ value_states = torch.cat(_value_states, dim=0)
954
+ attn_outputs = torch.cat(_attn_outputs, dim=0)
955
+
956
+ attn_outputs = self.o_proj(attn_outputs)
957
+ past_key_values[self.layer_idx] = key_states, value_states
958
+ return attn_outputs, past_key_values
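
Note on the refactor above: in 0.1.15 the forward_dict-based patching of DecoderOnlyWrapper is replaced by explicit nn.Module wrappers (DecoderOnlyForCausalLM, DecoderOnlyModel, DecoderOnlyLayer, and DecoderOnlyAttention / DecoderOnlyFlashAttention). The sketch below only illustrates how these classes compose, using the constructor signature shown in the diff; the model id and max_seq_len are placeholder assumptions, and the actual compilation/runtime path through the RBLN compiler is not invoked here.

from transformers import AutoModelForCausalLM
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper

# Any Llama-style checkpoint; placeholder id, replace with a real one.
causal_lm = AutoModelForCausalLM.from_pretrained("your/llama-style-model")

# Eager attention by default; passing kvcache_partition_len instead switches every
# layer to DecoderOnlyFlashAttention (see convert_to_rbln_causal_lm in the diff).
wrapper = DecoderOnlyWrapper(
    causal_lm=causal_lm,
    max_seq_len=4096,          # assumed value
    use_rotary_emb=True,
    kvcache_partition_len=None,
)

# The Huggingface modules are now wrapped rather than monkey-patched:
# DecoderOnlyForCausalLM -> DecoderOnlyModel -> [DecoderOnlyLayer(DecoderOnlyAttention), ...]
print(type(wrapper.causal_lm).__name__)                  # DecoderOnlyForCausalLM
print(type(wrapper.causal_lm.model.layers[0]).__name__)  # DecoderOnlyLayer

# The phase property propagates down the wrapper hierarchy; forward() derives it
# from the incoming sequence length (1 token -> "decode", otherwise "prefill").
wrapper.phase = "decode"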