optimum-rbln 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +7 -0
  68. optimum/rbln/utils/logging.py +37 -0
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  79. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
@@ -27,129 +27,82 @@ from typing import List, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig, PreTrainedModel
- from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 
+ from ....ops import register_rbln_custom_attention, register_rbln_custom_flash_attention
  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
 
- if is_torch_greater_or_equal_than_2_4:
-     register_fake = torch.library.register_fake
- else:
-     register_fake = torch.library.impl_abstract
-
-
  logger = logging.get_logger(__name__)
- """
- ##############################################################################
- # RBLN custom operation (python interface)
- # torch.compile custom operation
- # torch.library.define - kernel declaration
- # torch.library.impl - kernel implementation
- # torch.library.impl_abstract - symbolic trace
- ##############################################################################
- """
-
- # RBLN custom op(flash attention decode)
- torch.library.define(
-     "rbln_custom_ops::flash_attn_decode",
-     "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
- )
-
-
- @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
- def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
-     """
-     WORKAROUND:
-     Partition is declared as an argument to the function, even though it is
-     not actually used in the CPU implementation, this allows the rbln compiler
-     to perform flash attention operations with partition as an argument.
-     """
-     assert kcache.dim() == k.dim()
-     assert vcache.dim() == v.dim()
-     assert k.size(-2) == v.size(-2)
-     assert partition.dim() == 1
-     b = 0
-     if seq.dim() == 1:
-         s = seq[0]
-     elif seq.dim() == 0:
-         s = seq
-     else:
-         assert False
-     e = s + k.size(-2)
-     updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
-     updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
-     attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
-     attn_weight = attn_weight + mask
-     attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
-     attn_output = torch.matmul(attn_weight, updated_v)
-     return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
-
-
- @register_fake("rbln_custom_ops::flash_attn_decode")
- def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-     return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
-
-
- # RBLN custom op(flash attention prefill)
- torch.library.define(
-     "rbln_custom_ops::flash_attn_prefill",
-     "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
- )
-
-
- @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
- def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
-     """
-     WORKAROUND:
-     Partition is declared as an argument to the function, even though it is
-     not actually used in the CPU implementation, this allows the rbln compiler
-     to perform flash attention operations with partition as an argument.
-     """
-     assert kcache.dim() == k.dim()
-     assert vcache.dim() == v.dim()
-     assert k.size(-2) == v.size(-2)
-     assert partition.dim() == 1
-     if batch.dim() == 1:
-         b = batch[0]
-     elif batch.dim() == 0:
-         b = batch
-     else:
-         assert False
-     if seq.dim() == 1:
-         s = seq[0]
-     elif seq.dim() == 0:
-         s = seq
-     else:
-         assert False
-     e = s + k.size(-2)
-     updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
-     updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
-     attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
-     attn_weight = attn_weight + mask
-     attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
-     attn_output = torch.matmul(attn_weight, updated_v)
-     return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
-
-
- @register_fake("rbln_custom_ops::flash_attn_prefill")
- def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
-     return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
 
+ DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
+ DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
+ MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
+ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
+ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
 
- # RBLN custom op(cache update)
- torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
 
+ def validate_attention_method(
+     rbln_attn_impl: str, rbln_kvcache_partition_len: int, rbln_max_seq_len: int
+ ) -> Tuple[str, int]:
+     if rbln_kvcache_partition_len is not None:
+         if rbln_attn_impl == "eager":
+             raise ValueError(
+                 f"`rbln_kvcache_partition_len` is set to {rbln_kvcache_partition_len}, but KV cache partitioning"
+                 " is not supported with 'eager' attention. Please set `rbln_kvcache_partition_len` to None, "
+                 "or switch `rbln_attn_impl` to 'flash_attn' to use KV cache partitioning."
+             )
+         elif rbln_attn_impl is None:
+             rbln_attn_impl = "flash_attn"
+             logger.warning(
+                 "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
+                 "Since KV cache partitioning is only supported with flash attention, "
+                 "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
+             )
 
- @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
- def rbln_cache_update_cpu(cache, value, batch, seq):
-     updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
-     return updated_cache
+     rbln_attn_impl = "eager" if rbln_attn_impl is None else rbln_attn_impl
+     if rbln_attn_impl not in ["eager", "flash_attn"]:
+         raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
+
+     if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
+         rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+     ## Checking Constraints...
+     # Constraint of eager attention:
+     # - `max_seq_len` <= 32k
+
+     # Constraints of flash attention:
+     # 1. `max_seq_len` should be multiple of `partition_len`.
+     # 2. 4k <= `partition_len` <= 32k.
+     # 3. `max_seq_len` should be larger then 8k.
+     if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+         raise ValueError(
+             f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
+             f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
+             f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+             " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
+         )
 
+     if rbln_attn_impl == "flash_attn":
+         if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
+             raise ValueError(
+                 f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
+                 f"when using 'flash_attn'. Please adjust either value to meet this requirement."
+             )
+         elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+             raise ValueError(
+                 f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                 f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                 f"Please provide a valid value within this range."
+             )
+         elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+             raise ValueError(
+                 f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
+                 f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
+                 "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
+             )
 
- @register_fake("rbln_custom_ops::rbln_cache_update")
- def rbln_cache_update_abstract(cache, value, batch, seq):
-     return torch.empty_like(cache)
+     return rbln_attn_impl, rbln_kvcache_partition_len
 
 
  class DecoderOnlyWrapper(nn.Module):
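
For reference, the constraint logic added above resolves roughly as follows. This is an illustrative sketch, not part of the package diff; the return value is the resolved (attn_impl, kvcache_partition_len) pair and the import path is taken from the file list (item 44):

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import validate_attention_method

validate_attention_method(None, None, 4_096)           # -> ("eager", None)
validate_attention_method(None, 8_192, 32_768)         # -> ("flash_attn", 8_192), after the implicit-switch warning
validate_attention_method("flash_attn", None, 32_768)  # -> ("flash_attn", 16_384), the default partition length
validate_attention_method("eager", None, 65_536)       # raises ValueError: 65_536 exceeds the 32_768 eager limit
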
@@ -169,11 +122,23 @@ class DecoderOnlyWrapper(nn.Module):
          causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
          max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
          use_rotary_emb (bool): Whether to use rotary position embeddings
+         attn_impl (str): The attention implementation to use.
+             - "eager": Uses the standard attention.
+             - "flash_attn": Uses flash attention. When set,
+                 the key/value cache is partitioned into chunks of length
+                 `kvcache_partition_len`.
          kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
-             If provided, uses flash attention; if None, uses standard attention
+             This is only relevant if `attn_impl` is set to "flash_attn`
      """
 
-     def __init__(self, causal_lm: PreTrainedModel, max_seq_len, use_rotary_emb: bool, kvcache_partition_len=None):
+     def __init__(
+         self,
+         causal_lm: PreTrainedModel,
+         max_seq_len: int,
+         use_rotary_emb: bool,
+         attn_impl: str,
+         kvcache_partition_len: Optional[int] = None,
+     ):
          super().__init__()
          self.config = causal_lm.config
 
@@ -182,14 +147,21 @@ class DecoderOnlyWrapper(nn.Module):
          else:
              self.rotary_emb = None
 
-         if kvcache_partition_len is not None:
-             # WORKAROUND : for passing partition length as a value to the rbln compiler.
-             # What is actually used is the shape of this tensor.
-             self.attn_impl = "flash_attn"
-             logger.info(f"Using flash-attention. (partition length : {kvcache_partition_len})")
+         self.attn_impl = attn_impl
+         if self.attn_impl == "flash_attn":
+             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+             register_rbln_custom_flash_attention()
+         elif self.attn_impl == "eager":
+             self.kvcache_partition_len = None
+             register_rbln_custom_attention()
          else:
-             self.attn_impl = "eager"
-         self.kvcache_partition_len = kvcache_partition_len
+             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
+
+         if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+             raise ValueError(
+                 f"kvcache_partition_len({kvcache_partition_len}) should be lower"
+                 f" or equal to max_seq_len({max_seq_len})!"
+             )
 
          self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
 
@@ -213,12 +185,12 @@ class DecoderOnlyWrapper(nn.Module):
 
              new_layer = DecoderOnlyLayer(layer, new_self_attn)
              new_layers.append(new_layer)
-         new_model = DecoderOnlyModel(causal_lm.model, new_layers)
+         new_model = DecoderOnlyModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
          new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
          return new_causal_lm
 
      @property
-     def phase(self):
+     def phase(self) -> str:
          return self._phase
 
      @phase.setter
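
Putting the new constructor surface from the hunks above together, a hedged construction sketch (argument names come from the diff; `model` stands for any supported Huggingface causal LM already loaded in memory, and the values are simply chosen to satisfy the validation rules):

wrapper = DecoderOnlyWrapper(
    causal_lm=model,               # a transformers PreTrainedModel causal LM
    max_seq_len=32_768,
    use_rotary_emb=True,
    attn_impl="flash_attn",        # now explicit: "eager" or "flash_attn"
    kvcache_partition_len=16_384,  # only used by the "flash_attn" path
)
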
@@ -226,21 +198,32 @@ class DecoderOnlyWrapper(nn.Module):
          self._phase = phase
          self.causal_lm.phase = phase
 
-     def forward(
-         self,
-         input_ids_or_inputs_embeds,
-         attention_mask,
-         cache_position,
-         batch_position,
-         query_position,
-         *past_key_values,
-     ):
+     def forward(self, *args):
+         if self.phase == "decode":
+             (
+                 input_ids_or_inputs_embeds,
+                 attention_mask,
+                 cache_position,
+                 *past_key_values,
+             ) = args
+             batch_position = torch.tensor(0, dtype=torch.int16)
+             query_position = None
+         elif self.phase == "prefill":
+             (
+                 input_ids_or_inputs_embeds,
+                 attention_mask,
+                 cache_position,
+                 batch_position,
+                 query_position,
+                 *past_key_values,
+             ) = args
+         else:
+             raise ValueError(f"Unknown phase: {self.phase}")
+
          if input_ids_or_inputs_embeds.ndim == 2:
-             # It is input_ids
              input_ids = input_ids_or_inputs_embeds
              inputs_embeds = None
          elif input_ids_or_inputs_embeds.ndim == 3:
-             # It is inputs_embeds
              input_ids = None
              inputs_embeds = input_ids_or_inputs_embeds
          else:
@@ -248,15 +231,9 @@ class DecoderOnlyWrapper(nn.Module):
 
          if len(past_key_values) != 2 * self.num_hidden_layers:
              raise ValueError(
-                 f"Different past_key_values to model's config. {len(past_key_values)} != {self.num_hidden_layers}"
+                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
              )
 
-         seq_len = input_ids_or_inputs_embeds.shape[1]
-         if seq_len == 1:
-             self.phase = "decode"
-         else:
-             self.phase = "prefill"
-
          # [key, value] * n_layer -> ( (key, value) ) * n_layer
          # cache shape : batch, n_heads, 1, max_seq_len, head_dim
          _past_key_values = []
@@ -286,8 +263,7 @@ class DecoderOnlyWrapper(nn.Module):
              _present_key_values = _present_key_values + (key_states, value_states)
          present_key_values = _present_key_values
 
-         # batch_position + query_position is dummy output node to keep the number of outputs
-         return logit, present_key_values, batch_position + query_position
+         return logit, present_key_values
 
 
  class DecoderOnlyForCausalLM(nn.Module):
@@ -371,13 +347,12 @@ class DecoderOnlyModel(nn.Module):
          _phase: Current processing phase ("prefill" or "decode")
      """
 
-     mask_fmin = torch.finfo(torch.float16).min
-
-     def __init__(self, model, layers: List["DecoderOnlyLayer"]):
+     def __init__(self, model, layers: List["DecoderOnlyLayer"], partition_len=None):
          super().__init__()
          self._original_mod = model
          self.layers = nn.ModuleList(layers)
          self._phase = "prefill"
+         self.partition_len = partition_len
 
      @property
      def phase(self):
@@ -389,10 +364,26 @@ class DecoderOnlyModel(nn.Module):
          for layer in self.layers:
              layer.phase = phase
 
+     @property
+     def attn_impl(self) -> str:
+         return "eager" if self.partition_len is None else "flash_attn"
+
      @property
      def hidden_multiplier(self):
          return 1
 
+     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
+         if self.attn_impl != "flash_attn":
+             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
+
+         partition_len = self.partition_len
+         num_partition = max_seq_len // partition_len
+
+         cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
+         pidx = torch.arange(num_partition)
+         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
+         return cache_pos_for_partitions
+
      def get_last_layernorm(self) -> nn.LayerNorm:
          return self._original_mod.norm
 
@@ -425,7 +416,6 @@ class DecoderOnlyModel(nn.Module):
              inputs_embeds = self.get_embedding()(input_ids)
 
          hidden_states = inputs_embeds * self.hidden_multiplier
-         attention_mask = (1 - attention_mask) * self.mask_fmin
 
          # get cos,sin vector if needed
          if rotary_emb is not None:
@@ -446,14 +436,19 @@ class DecoderOnlyModel(nn.Module):
              cos, sin = None, None
 
          # (batch, seq_len) -> (batch,)
-         current_steps = cache_position[:, 0]
+         seq_positions = cache_position[:, 0]
+         if self.attn_impl == "flash_attn":
+             max_seq_len = past_key_values[0][0].shape[-2]
+             seq_positions = self.convert_sequence_positions_for_flash_attn(
+                 seq_positions=seq_positions, max_seq_len=max_seq_len
+             )
 
          present_key_values = past_key_values
          for layer in self.layers:
              hidden_states, present_key_values = layer(
                  hidden_states=hidden_states,
                  attention_mask=attention_mask,
-                 current_steps=current_steps,
+                 seq_positions=seq_positions,
                  batch_position=batch_position,
                  past_key_values=present_key_values,
                  cos=cos,
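
Concretely, convert_sequence_positions_for_flash_attn (added two hunks above and called here) spreads each sequence position over the KV cache partitions before the layers run. An illustrative evaluation, not part of the diff, assuming partition_len=16_384 and max_seq_len=32_768 (two partitions):

seq_positions = torch.tensor([20_000, 5_000])
# clamp(position - k * partition_len, 0, partition_len) for partition k = 0, 1:
# tensor([[16_384,  3_616],   # sequence 0 fills partition 0 and reaches 3_616 into partition 1
#         [ 5_000,      0]])  # sequence 1 sits entirely inside partition 0
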
@@ -514,20 +509,19 @@ class DecoderOnlyLayer(nn.Module):
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
-         current_steps: torch.LongTensor,
+         seq_positions: torch.LongTensor,
          batch_position: torch.Tensor,
          past_key_values: Tuple[Tuple[torch.Tensor]],
          cos: Optional[torch.Tensor] = None,
          sin: Optional[torch.Tensor] = None,
      ):
          residual = hidden_states
-
          hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
          hidden_states, present_key_values = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
-             current_steps=current_steps,
+             seq_positions=seq_positions,
              batch_position=batch_position,
              past_key_values=past_key_values,
              cos=cos,
@@ -561,15 +555,34 @@ class DecoderOnlyAttention(nn.Module):
          self.layer_idx = self_attn.layer_idx
          self.num_heads = self._original_mod.num_heads
          self.head_dim = self._original_mod.head_dim
-         self.phase = "prefill"
+         self._phase = "prefill"
+         self.scale = torch.tensor(self.get_attn_scale())
+
+         if hasattr(self._original_mod, "num_key_value_heads"):
+             self.num_key_value_heads = self._original_mod.num_key_value_heads
+         else:
+             self.num_key_value_heads = self._original_mod.num_heads
+
+         self.attention = self.get_attention()
          self.__post_init__()
 
+     @property
+     def phase(self):
+         return self._phase
+
+     @phase.setter
+     def phase(self, phase: str):
+         self._phase = phase
+         self.attention.phase = phase
+
+     def get_attention(self):
+         return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads)
+
      def __post_init__(self):
          self.q_proj = self._original_mod.q_proj
          self.k_proj = self._original_mod.k_proj
          self.v_proj = self._original_mod.v_proj
          self.o_proj = self._original_mod.o_proj
-         self.num_key_value_heads = self._original_mod.num_key_value_heads
 
      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """Projects input hidden states into query, key, and value representations.
@@ -588,97 +601,17 @@ class DecoderOnlyAttention(nn.Module):
      def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
          return apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-     def rbln_attention(
-         self,
-         query_state,
-         key_state,
-         value_state,
-         attn_mask,
-         batch_idx,
-         past_key_state,
-         past_value_state,
-         current_step,
-         # below are designed for Midm, GPT which requires to support scaling for attention weights
-         # TODO(jongho): Merge and manage scales generally
-         layer_idx=None,
-         scale_attn_weights: bool = None,
-         scale_attn_by_inverse_layer_idx: bool = None,
-         scale_qk_by_inverse_layer_idx: bool = None,
-     ):
-         """Compute attention with static shapes and explicit cache management.
-
-         Args:
-             query_state: Query tensor [1, num_heads, 1, head_dim]
-             key_state: Key tensor [1, num_heads, seq_len, head_dim]
-             value_state: Value tensor [1, num_heads, seq_len, head_dim]
-             attn_mask: Attention mask tensor
-             batch_idx: Batch index for cache lookup
-             past_key_state: Previous key cache states
-             past_value_state: Previous value cache states
-             current_step: Current position in sequence
-
-         Returns:
-             Tuple of (attention_output, key_state, value_state)
-         """
-         # Implementation details.
-         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-         key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
-         value_state = value_state.unsqueeze(2)
-         attn_mask = attn_mask.unsqueeze(2)
-
-         query_state = query_state.view(
-             1,
-             self.num_key_value_heads,
-             self.num_heads // self.num_key_value_heads,
-             -1, # seq len
-             self.head_dim,
-         ) #
-
-         kend = current_step + key_state.shape[-2]
-         vend = current_step + value_state.shape[-2]
-
-         key_state = (
-             past_key_state[batch_idx]
-             .unsqueeze(0)
-             .unsqueeze(2)
-             .slice_scatter(key_state, dim=-2, start=current_step, end=kend)
-         )
-         value_state = (
-             past_value_state[batch_idx]
-             .unsqueeze(0)
-             .unsqueeze(2)
-             .slice_scatter(value_state, dim=-2, start=current_step, end=vend)
-         )
-
-         attn_weight = torch.matmul(query_state, key_state.transpose(3, 4))
-         attn_weight = attn_weight / math.sqrt(self.head_dim)
-
-         if layer_idx is not None and (scale_attn_by_inverse_layer_idx or scale_qk_by_inverse_layer_idx):
-             attn_weight = attn_weight / float(layer_idx + 1)
-
-         attn_weight += attn_mask
-
-         if layer_idx is not None and scale_qk_by_inverse_layer_idx:
-             attn_weight = attn_weight * float(layer_idx + 1)
-
-         attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-
-         attn_output = torch.matmul(attn_weight, value_state)
-
-         attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
-
-         return attn_output, key_state, value_state
+     def get_attn_scale(self):
+         return 1 / math.sqrt(self.head_dim)
 
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
-         current_steps: torch.LongTensor,
+         seq_positions: torch.LongTensor,
          batch_position: torch.Tensor,
          past_key_values: Tuple[Tuple[torch.Tensor]],
-         cos: Optional[torch.Tensor] = None, # (batch, 1, prefill_size, head_dim)
+         cos: Optional[torch.Tensor] = None,
          sin: Optional[torch.Tensor] = None,
      ):
          batch_size, query_length, _ = hidden_states.size()
@@ -698,22 +631,24 @@ class DecoderOnlyAttention(nn.Module):
          if batch_size > 1 and self.phase == "prefill":
              raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
+         # TODO(jongho): flash attn legacy. (clone)
+         _seq_positions = seq_positions.clone().unsqueeze(1)
+
          _key_states = []
          _value_states = []
          _attn_outputs = []
          for b in range(batch_size):
-             current_step = current_steps[b]
-             attn_output, key_state, value_state = self.rbln_attention(
+             seq_position = _seq_positions[b][0]
+             attn_output, key_state, value_state = self.attention(
                  query_states[b].unsqueeze(0),
                  key_states[b].unsqueeze(0),
                  value_states[b].unsqueeze(0),
-                 attention_mask[b].unsqueeze(0)
-                 if self.phase == "decode"
-                 else attention_mask, # TODO(jongho): fix when msoftmax is supported
+                 attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
                  past_key_state=past_key_values[self.layer_idx][0],
                  past_value_state=past_key_values[self.layer_idx][1],
-                 batch_idx=b if self.phase == "decode" else batch_position,
-                 current_step=current_step,
+                 batch_position=b if self.phase == "decode" else batch_position,
+                 seq_position=seq_position,
+                 scale=self.scale,
              )
              _key_states.append(key_state)
              _value_states.append(value_state)
@@ -727,6 +662,87 @@ class DecoderOnlyAttention(nn.Module):
          return attn_outputs, past_key_values
 
 
+ class AttentionOp(nn.Module):
+     def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+         self.num_key_value_heads = num_key_value_heads
+         self.phase = "prefill"
+
+     def forward(
+         self,
+         query_state: torch.Tensor,
+         key_state: torch.Tensor,
+         value_state: torch.Tensor,
+         attn_mask: torch.Tensor,
+         batch_position: torch.Tensor,
+         past_key_state: torch.Tensor,
+         past_value_state: torch.Tensor,
+         seq_position: torch.Tensor,
+         scale: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Compute attention with static shapes and explicit cache management.
+
+         Args:
+             query_state: Query tensor [1, num_heads, 1, head_dim]
+             key_state: Key tensor [1, num_heads, seq_len, head_dim]
+             value_state: Value tensor [1, num_heads, seq_len, head_dim]
+             attn_mask: Attention mask tensor ∈ {0, 1}
+             batch_position: Batch index for cache lookup
+             past_key_state: Previous key cache states
+             past_value_state: Previous value cache states
+             seq_position: Current position in sequence
+             scale: Scale applied to attn weights
+
+         Returns:
+             Tuple of (attention_output, key_state, value_state)
+         """
+         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+         key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
+         value_state = value_state.unsqueeze(2)
+         attn_mask = attn_mask.unsqueeze(2)
+
+         query_state = query_state.view(
+             1,
+             self.num_key_value_heads,
+             self.num_heads // self.num_key_value_heads,
+             -1, # seq len
+             self.head_dim,
+         )
+
+         if self.phase == "decode":
+             attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_decode(
+                 query_state,
+                 key_state,
+                 value_state,
+                 attn_mask,
+                 past_key_state.unsqueeze(2),
+                 past_value_state.unsqueeze(2),
+                 seq_position,
+                 scale,
+             )
+
+         else:
+             attn_output, key_state, value_state = torch.ops.rbln_custom_ops.attn_prefill(
+                 query_state,
+                 key_state,
+                 value_state,
+                 attn_mask,
+                 past_key_state.unsqueeze(2),
+                 past_value_state.unsqueeze(2),
+                 batch_position,
+                 seq_position,
+                 scale,
+             )
+
+         attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+         return attn_output, key_state.squeeze(2), value_state.squeeze(2)
+
+
  def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
      """Slice cos[cache_position], sin[cache_position] vector for the query."""
      if cache_position.shape[0] > 1:
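
The rbln_custom_ops symbols called by the new AttentionOp are no longer defined in this file; per the file list they live in the new optimum/rbln/ops package and are registered from DecoderOnlyWrapper.__init__ (the register_* calls shown earlier). A hedged sketch of that wiring, assuming optimum.rbln.ops re-exports the two registration helpers as the relative import above implies:

from optimum.rbln.ops import register_rbln_custom_attention, register_rbln_custom_flash_attention

register_rbln_custom_attention()        # expected to register rbln_custom_ops::attn_decode / attn_prefill (eager path)
register_rbln_custom_flash_attention()  # expected to register rbln_custom_ops::flash_attn_decode / flash_attn_prefill
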
@@ -821,40 +837,83 @@ class RotaryEmbedding(nn.Module):
 
  class DecoderOnlyFlashAttention(DecoderOnlyAttention):
      def __init__(self, self_attn, kvcache_partition_len):
+         self.kvcache_partition_size = kvcache_partition_len
          super().__init__(self_attn=self_attn)
-         self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
 
-     def get_cache_pos_for_partitions(self, current_steps, batch_size, max_seq_len):
-         partition_len = self.kvcache_partition_size.size()[0]
-         num_partition = max_seq_len // partition_len
-         cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
-         if self.phase == "decode":
-             for b_idx in range(batch_size):
-                 cache_pos = current_steps[b_idx]
-                 for p_idx in range(num_partition):
-                     cache_pos_for_partitions[b_idx][p_idx] = torch.clamp(
-                         cache_pos - partition_len * p_idx, 0, partition_len
-                     )
-         else: # prefill
-             cache_pos = current_steps[0]
-             for p_idx in range(num_partition):
-                 cache_pos_for_partitions[0][p_idx] = torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len)
+     def get_attention(self):
+         return FlashAttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.kvcache_partition_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: torch.LongTensor,
+         batch_position: torch.Tensor,
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+     ):
+         batch_size, query_length, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+             1, 2
+         )
+         # b, num_head, query, head_dim
+
+         if cos is not None and sin is not None:
+             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+         _key_states = []
+         _value_states = []
+         _attn_outputs = []
+         for b in range(batch_size):
+             seq_position = seq_positions[b][0] # FIXME: Remove take-take pattern matching
+             attn_output, key_state, value_state = self.attention(
+                 query_states[b].unsqueeze(0),
+                 key_states[b].unsqueeze(0),
+                 value_states[b].unsqueeze(0),
+                 attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
+                 past_key_state=past_key_values[self.layer_idx][0],
+                 past_value_state=past_key_values[self.layer_idx][1],
+                 batch_position=b if self.phase == "decode" else batch_position,
+                 seq_position=seq_position,
+                 scale=self.scale,
+             )
+             _key_states.append(key_state)
+             _value_states.append(value_state)
+             _attn_outputs.append(attn_output)
+         key_states = torch.cat(_key_states, dim=0)
+         value_states = torch.cat(_value_states, dim=0)
+         attn_outputs = torch.cat(_attn_outputs, dim=0)
+
+         attn_outputs = self.o_proj(attn_outputs)
+         past_key_values[self.layer_idx] = key_states, value_states
+         return attn_outputs, past_key_values
 
-         return cache_pos_for_partitions
 
-     def rbln_flash_attention(
+ class FlashAttentionOp(AttentionOp):
+     def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, kvcache_partition_len: int):
+         super().__init__(num_heads=num_heads, head_dim=head_dim, num_key_value_heads=num_key_value_heads)
+         self.kvcache_partition_size = kvcache_partition_len
+
+     def forward(
          self,
          query_state,
          key_state,
          value_state,
          attn_mask,
-         batch_idx,
+         batch_position,
          past_key_state,
          past_value_state,
-         cache_pos_for_partitions,
+         seq_position,
+         scale,
      ):
          # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-         key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
+         key_state = key_state.unsqueeze(2)
          value_state = value_state.unsqueeze(2)
          attn_mask = attn_mask.unsqueeze(2)
 
@@ -866,9 +925,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
              self.head_dim,
          )
 
-         # RBLN custom flash attention(decode), dummy batch index
          if self.phase == "decode":
-             sidx = cache_pos_for_partitions[batch_idx][0]
              attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
                  query_state,
                  key_state,
@@ -876,11 +933,11 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
                  attn_mask,
                  past_key_state.unsqueeze(2),
                  past_value_state.unsqueeze(2),
-                 sidx,
+                 seq_position,
+                 scale,
                  self.kvcache_partition_size,
              )
          else:
-             sidx = cache_pos_for_partitions[0][0]
              attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
                  query_state,
                  key_state,
@@ -888,8 +945,9 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
                  attn_mask,
                  past_key_state.unsqueeze(2),
                  past_value_state.unsqueeze(2),
-                 batch_idx,
-                 sidx,
+                 batch_position,
+                 seq_position,
+                 scale,
                  self.kvcache_partition_size,
              )
 
@@ -899,60 +957,3 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
          attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
 
          return attn_output, key_state, value_state
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor,
-         current_steps: torch.LongTensor,
-         batch_position: torch.Tensor,
-         past_key_values: Tuple[Tuple[torch.Tensor]],
-         cos: Optional[torch.Tensor] = None,
-         sin: Optional[torch.Tensor] = None,
-     ):
-         batch_size, query_length, _ = hidden_states.size()
-
-         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-             1, 2
-         )
-         # b, num_head, query, head_dim
-
-         max_seq_len = past_key_values[self.layer_idx][0].shape[-2]
-
-         if cos is not None and sin is not None:
-             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-         cache_pos_for_partitions = self.get_cache_pos_for_partitions(
-             current_steps, batch_size=batch_size, max_seq_len=max_seq_len
-         ) # batch_size, num_partitions
-
-         _key_states = []
-         _value_states = []
-         _attn_outputs = []
-         for b in range(batch_size):
-             attn_output, key_state, value_state = self.rbln_flash_attention(
-                 query_states[b].unsqueeze(0),
-                 key_states[b].unsqueeze(0),
-                 value_states[b].unsqueeze(0),
-                 attention_mask[b].unsqueeze(0)
-                 if self.phase == "decode"
-                 else attention_mask, # TODO(jongho): fix when msoftmax is supported
-                 past_key_state=past_key_values[self.layer_idx][0],
-                 past_value_state=past_key_values[self.layer_idx][1],
-                 batch_idx=b if self.phase == "decode" else batch_position,
-                 cache_pos_for_partitions=cache_pos_for_partitions,
-             )
-             _key_states.append(key_state)
-             _value_states.append(value_state)
-             _attn_outputs.append(attn_output)
-         key_states = torch.cat(_key_states, dim=0)
-         value_states = torch.cat(_value_states, dim=0)
-         attn_outputs = torch.cat(_attn_outputs, dim=0)
-
-         attn_outputs = self.o_proj(attn_outputs)
-         past_key_values[self.layer_idx] = key_states, value_states
-         return attn_outputs, past_key_values