optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
+++ b/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -21,233 +21,259 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- from typing import Optional, Tuple, Union
+ from typing import Dict, Optional, Tuple, Union

  import torch
  import torch.nn as nn
- from transformers.modeling_outputs import (
-     BaseModelOutputWithPast,
-     BaseModelOutputWithPastAndCrossAttentions,
- )
- from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+ from transformers.modeling_outputs import BaseModelOutputWithPast

+ from ...cache_utils import RebelDynamicCache_4D

- class _GPT2Attention(GPT2Attention):
-     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-         if self.scale_attn_weights:
-             attn_weights = attn_weights / torch.full(
-                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-             )

-         # Layer-wise attention scaling
-         if self.scale_attn_by_inverse_layer_idx:
-             attn_weights = attn_weights / float(self.layer_idx + 1)
-
-         # -------------------
-         # Below are deleted since "where" op does not supported on RBLN graph.
-         # -------------------
-         # if not self.is_cross_attention:
-         #     # if only "normal" attention layer implements causal mask
-         #     query_length, key_length = query.size(-2), key.size(-2)
-         #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-         #     mask_value = torch.finfo(attn_weights.dtype).min
-         #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-         #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-         #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-         #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+ class GPT2LMHeadModelWrapper(torch.nn.Module):
+     def __init__(self, model, max_seq_len):
+         super().__init__()
+         self.model = model.transformer
+         self.lm_head = model.lm_head
+         self.config = model.config
+         self.max_seq_len = max_seq_len
+         self.forward_dict = self.get_forward_dict()
+
+     def get_forward_dict(self):
+         forward_dict = {
+             "wrapper": _GPT2Model.forward,
+             "model": _GPT2Block.forward,
+             "decoder_layer": _GPT2Attention.forward,
+         }
+         return forward_dict

-         if attention_mask is not None:
-             # Apply the attention mask
-             attn_weights = attn_weights + attention_mask
+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         cache_position,
+         batch_position,
+         *past_key_values,
+     ):
+         if input_ids.shape[1] == 1:
+             rbln_batch_position = None
+         else:
+             rbln_batch_position = batch_position
+
+         # Formatting list of past_kv to DynamicCache class.
+         past_key_value = RebelDynamicCache_4D.from_input_format(
+             cache_position,
+             self.config.n_layer,
+             *past_key_values,
+         )

-         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+         outputs = self.forward_dict["wrapper"](
+             self.model,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=cache_position,
+             past_key_value=past_key_value,
+             batch_ids=rbln_batch_position,
+             forward_dict=self.forward_dict,
+             # rotary_emb differenct from_llama
+         )

-         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
-         attn_weights = attn_weights.type(value.dtype)
-         # attn_weights = self.attn_dropout(attn_weights)
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)

-         # Mask heads if we want to
-         if head_mask is not None:
-             attn_weights = attn_weights * head_mask
+         output = (logits,) + outputs[1:]

-         attn_output = torch.matmul(attn_weights, value)
+         return output, batch_position

-         return attn_output, attn_weights

+ class _GPT2Model:
      def forward(
          self,
-         hidden_states: Optional[Tuple[torch.FloatTensor]],
-         layer_past: Optional[Tuple[torch.Tensor]] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = False,
-         output_attentions: Optional[bool] = False,
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[RebelDynamicCache_4D] = None,
+         batch_ids: Optional[torch.LongTensor] = None,
+         forward_dict: Optional[Dict[str, classmethod]] = None,
+     ) -> BaseModelOutputWithPast:
+         b_size, q_len = input_ids.shape
+         inputs_embeds = self.wte(input_ids)

-         query = self._split_heads(query, self.num_heads, self.head_dim)
-         key = self._split_heads(key, self.num_heads, self.head_dim)
-         value = self._split_heads(value, self.num_heads, self.head_dim)
+         if position_ids.shape[0] > 1:
+             position_embeds = []
+             for b_idx in range(b_size):
+                 position_embed = self.wpe(position_ids[b_idx])
+                 # position_embed = position_embed.dtype(inputs_embeds.dtype)
+                 position_embeds.append(position_embed)

-         if layer_past is not None:
-             past_key, past_value = layer_past
-             query_length = query.shape[-2]
+             position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+         else:
+             position_embeds = self.wpe(position_ids)

-             key = torch.slice_scatter(past_key, key, dim=2, start=cache_position, end=cache_position + query_length)
-             value = torch.slice_scatter(
-                 past_value, value, dim=2, start=cache_position, end=cache_position + query_length
-             )
+         hidden_states = inputs_embeds + position_embeds

-         present = (key, value)
-         attn_output, _ = _GPT2Attention._attn(self, query, key, value, attention_mask, head_mask)
+         # GPT2Attention mask.
+         # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
+         attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

-         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+         for layer_idx, block in enumerate(self.h):
+             hidden_states, updated_cache = forward_dict["model"](
+                 block,
+                 hidden_states,
+                 layer_idx,
+                 attention_mask=attention_mask,
+                 past_key_value=past_key_value,
+                 position_ids=position_ids,
+                 batch_ids=batch_ids,
+                 forward_dict=forward_dict,
+             )

-         attn_output = self.c_proj(attn_output)
-         attn_output = self.resid_dropout(attn_output)
+         hidden_states = self.ln_f(hidden_states)
+         output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
+         hidden_states = hidden_states.view(output_shape)

-         outputs = (attn_output, present)
+         # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+         next_cache = updated_cache.to_legacy_cache()

-         return outputs
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+         )


- class _GPT2Block(GPT2Block):
+ class _GPT2Block:
      def forward(
          self,
          hidden_states: Optional[Tuple[torch.FloatTensor]],
-         layer_past: Optional[Tuple[torch.Tensor]] = None,
+         layer_idx: int,
          attention_mask: Optional[torch.FloatTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = False,
-         output_attentions: Optional[bool] = False,
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[RebelDynamicCache_4D] = None,
+         batch_ids: Optional[torch.LongTensor] = None,
+         forward_dict: Optional[Dict[str, classmethod]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
          residual = hidden_states
          hidden_states = self.ln_1(hidden_states)

-         attn_outputs = _GPT2Attention.forward(
+         hidden_states, k, v = forward_dict["decoder_layer"](
              self.attn,
-             hidden_states,
-             layer_past=layer_past,
+             hidden_states=hidden_states,
              attention_mask=attention_mask,
-             head_mask=head_mask,
-             cache_position=cache_position,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             batch_index=batch_ids,
          )
+         past_key_value.assign(k, v, layer_idx)

-         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-         outputs = attn_outputs[1:]
          # residual connection
-         hidden_states = attn_output + residual
+         hidden_states = residual + hidden_states

          residual = hidden_states
          hidden_states = self.ln_2(hidden_states)
-         feed_forward_hidden_states = self.mlp(hidden_states)
-         # residual connection
-         hidden_states = residual + feed_forward_hidden_states
-
-         outputs = (hidden_states,) + outputs
-         return outputs  # hidden_states, present, (attentions, cross_attentions)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states

+         return hidden_states, past_key_value

- class _GPT2Model(GPT2Model):
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-         input_shape = input_ids.size()

-         if position_ids is None:
-             # force dtype to torch.long -> torch.int32 (to match cache_position)
-             position_ids = torch.arange(0, input_shape[-1], dtype=torch.int32) + cache_position
-             position_ids = position_ids.unsqueeze(0)
-
-         # GPT2Attention mask.
-         # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-         attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
-         # Prepare head mask if needed
-         # 1.0 in head_mask indicate we keep the head
-         # attention_probs has shape bsz x n_heads x N x N
-         # head_mask has shape n_layer x batch x n_heads x N x N
-         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-         inputs_embeds = self.wte(input_ids)
-         position_embeds = self.wpe(position_ids)
-         hidden_states = inputs_embeds + position_embeds
-
-         hidden_states = self.drop(hidden_states)
-
-         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+ class _GPT2Attention:
+     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+         attn_weights = torch.matmul(query, key.transpose(-1, -2))

-         presents = ()
-         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-             outputs = _GPT2Block.forward(
-                 block,
-                 hidden_states,
-                 layer_past=layer_past,
-                 attention_mask=attention_mask,
-                 head_mask=head_mask[i],
-                 cache_position=cache_position,
+         if self.scale_attn_weights:
+             attn_weights = attn_weights / torch.full(
+                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
              )
-             hidden_states = outputs[0]

-             presents = presents + (outputs[1],)
+         # Layer-wise attention scaling
+         if self.scale_attn_by_inverse_layer_idx:
+             attn_weights = attn_weights / float(self.layer_idx + 1)

-         hidden_states = self.ln_f(hidden_states)
-         hidden_states = hidden_states.view(output_shape)
-         return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=presents)
+         # -------------------
+         # Below are deleted since "where" op does not supported on RBLN graph.
+         # -------------------
+         # if not self.is_cross_attention:
+         #     # if only "normal" attention layer implements causal mask
+         #     query_length, key_length = query.size(-2), key.size(-2)
+         #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+         #     mask_value = torch.finfo(attn_weights.dtype).min
+         #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+         #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+         #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+         #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

+         # Apply the attention mask
+         attn_weights.view(
+             -1,
+         )
+         attn_weights = attn_weights + attention_mask
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+         attn_output = torch.matmul(attn_weights, value)

- class GPT2LMHeadModelWrapper(torch.nn.Module):
-     def __init__(self, gpt):
-         super().__init__()
-         self.model = gpt
+         return attn_output, attn_weights

      def forward(
          self,
-         input_ids: torch.Tensor,
-         past_key_values: torch.Tensor,
-         attention_mask: torch.Tensor,
-         cache_position: torch.LongTensor,
-     ):
-         kv_cache = []
-         for i in range(self.model.config.n_layer):
-             kv_cache.append((past_key_values[i, 0], past_key_values[i, 1]))
-
-         transformer_outputs = _GPT2Model.forward(
-             self.model.transformer,
-             input_ids=input_ids,
-             past_key_values=kv_cache,
-             attention_mask=attention_mask,
-             cache_position=cache_position,
-         )
-
-         hidden_states = transformer_outputs[0]
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[RebelDynamicCache_4D] = None,
+         batch_index: Optional[int] = None,
+         **kwargs,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+         bsz, q_len, _ = hidden_states.size()
+         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

-         # TODO : Use query_length here to pick last logit
-         # batch_size, sequence_length = hidden_states.shape[:2]
-         # hidden_states = hidden_states.view(batch_size * sequence_length, -1)
-         # hidden_states = torch.nn.functional.embedding(query_length, hidden_states)
-         # hidden_states = hidden_states.view(batch_size, 1, -1)
+         querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
+         keys = self._split_heads(key, self.num_heads, self.head_dim)
+         values = self._split_heads(value, self.num_heads, self.head_dim)
+
+         # Decoder
+         if (batch_index is None or batch_index == -1) and bsz > 1:
+             all_keys = []
+             all_values = []
+             all_attn_output = []
+
+             for b in range(bsz):
+                 query = querys[b].unsqueeze(0)
+                 attn_mask = attention_mask[b].unsqueeze(0)
+                 key = keys[b].unsqueeze(0)
+                 value = values[b].unsqueeze(0)
+
+                 key, value = past_key_value.update(
+                     key,
+                     value,
+                     self.layer_idx,
+                     b,
+                 )
+
+                 attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
+                 attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+                 all_keys.append(key)
+                 all_values.append(value)
+                 all_attn_output.append(attn_output)
+
+             keys = torch.cat(all_keys, dim=0)
+             values = torch.cat(all_values, dim=0)
+             attn_output = torch.cat(all_attn_output, dim=0)
+
+         # Prefill
+         else:
+             if batch_index is None or batch_index == -1:
+                 batch_index = 0
+
+             keys, values = past_key_value.update(
+                 keys,
+                 values,
+                 self.layer_idx,
+                 batch_index,
+                 read_first_step=True,
+             )

-         lm_logits = self.model.lm_head(hidden_states)
-         kv_cache = transformer_outputs[1]
+             attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
+             attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

-         past_key_values = []
-         for i in range(self.model.config.n_layer):
-             past_key_values.append(torch.stack(kv_cache[i], dim=0))
-         past_key_values = torch.stack(past_key_values, dim=0)
+         attn_output = self.c_proj(attn_output)

-         return lm_logits, past_key_values
+         return attn_output, keys, values
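
Note on the interface change in this hunk: in 0.1.4 the wrapper was constructed as GPT2LMHeadModelWrapper(model) and its forward consumed one stacked past_key_values tensor, while in 0.1.8 the constructor also takes max_seq_len and forward receives the cache as flat *past_key_values rebuilt through RebelDynamicCache_4D. A minimal sketch of constructing the new-style wrapper follows; it is illustrative only, the checkpoint id and max_seq_len value are assumptions, and the RBLN compile step that would normally consume this wrapper is omitted.

# Illustrative sketch only (not part of the diff). Assumes the module path shown
# in the file list above; the checkpoint id and max_seq_len are arbitrary choices.
from transformers import GPT2LMHeadModel

from optimum.rbln.transformers.models.gpt2.gpt2_architecture import GPT2LMHeadModelWrapper

hf_model = GPT2LMHeadModel.from_pretrained("gpt2")            # stock Hugging Face GPT-2
wrapper = GPT2LMHeadModelWrapper(hf_model, max_seq_len=1024)  # new 0.1.8 signature

# The 0.1.4 equivalent was GPT2LMHeadModelWrapper(hf_model); its forward took
# (input_ids, past_key_values, attention_mask, cache_position) with a stacked
# cache tensor, whereas the 0.1.8 forward takes (input_ids, attention_mask,
# cache_position, batch_position, *past_key_values).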