cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/models/gpt2.py ADDED
@@ -0,0 +1,560 @@
+ from typing import Optional, Tuple, Union
+
+ import torch.nn.functional as f
+ import torch.utils.checkpoint
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Attention
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_func, flash_attn_varlen_func
+     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+ from cehrgpt.models.activations import RMSNorm
+
+ logger = logging.get_logger("transformers")
+
+
+ def is_sample_pack(attention_mask: torch.Tensor) -> bool:
+     """
+     Determines whether any sequence in the batch is likely sample-packed.
+
+     A sample-packed sequence is one where there are non-padding (1) tokens
+     after a padding (0) token, indicating multiple sequences packed together
+     with padding as a separator.
+
+     Args:
+         attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
+             where 1 indicates a real token and 0 indicates padding.
+
+     Returns:
+         bool: True if any sample in the batch is sample-packed, False otherwise.
+     """
+
+     # If the attention_mask is left-padded, flip it so the same logic below applies
+     if (attention_mask[:, 0] == 0).any():
+         attention_mask = attention_mask.flip(dims=[1])
+
+     nonzero_counts = attention_mask.sum(dim=1)
+     max_token_positions = torch.argmax(
+         attention_mask.to(torch.int32).flip(dims=[1]), dim=1
+     )
+     max_indices = attention_mask.shape[1] - 1 - max_token_positions
+     return torch.any(nonzero_counts < (max_indices + 1)).item()
+
+
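# Illustrative sketch (not part of the released file): how is_sample_pack tells
# ordinary right padding apart from packed samples separated by a padding token.
padded_mask = torch.tensor([[1, 1, 1, 0, 0]])  # plain right padding
packed_mask = torch.tensor([[1, 1, 0, 1, 1]])  # two length-2 samples split by a 0
assert not is_sample_pack(padded_mask)  # no real token follows a padding token
assert is_sample_pack(packed_mask)      # real tokens appear after the 0 separator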
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
+ def _get_unpad_data(attention_mask):
+     # This infers sample packing
+     if is_sample_pack(attention_mask):
+         # Assume input: attention_mask shape = (batch, seq_len)
+         attention_mask = attention_mask.flatten()  # shape: (seq_len,)
+
+         # Compute max_index of the last non-zero element
+         nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
+         max_index = nonzero[-1].item()
+
+         # Pad the truncated attention mask
+         padded_attention_mask = f.pad(attention_mask[: max_index + 1], (0, 1), value=0)
+
+         # Indices of all tokens
+         indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
+
+         # Find where 0s occur (segment boundaries)
+         cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
+             padded_attention_mask == 0
+         ]
+
+         # Compute seqlens per segment
+         seqlens_in_batch = (
+             cumsum_seqlens_in_batch
+             - f.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
+         ).to(torch.int)
+
+         max_seqlen_in_batch = (
+             seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
+         )
+         cu_seqlens = f.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
+     else:
+         seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+         indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+         max_seqlen_in_batch = seqlens_in_batch.max().item()
+         cu_seqlens = f.pad(
+             torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+         )
+
+     return (
+         indices,
+         cu_seqlens,
+         max_seqlen_in_batch,
+     )
+
+
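# Illustrative sketch (not part of the released file): what _get_unpad_data yields
# for a packed mask in which a 0 separates two length-2 segments.
packed_mask = torch.tensor([[1, 1, 0, 1, 1]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(packed_mask)
# indices    -> tensor([0, 1, 3, 4])  positions of real tokens in the flattened mask
# cu_seqlens -> tensor([0, 2, 4])     cumulative segment lengths for flash_attn_varlen_func
# max_seqlen -> 2                     length of the longest packed segment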
+ class RotaryPositionEmbedding(nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+         self.dim = dim
+         self.inv_freq = 1.0 / (10000 ** (torch.linspace(0, 2, steps=dim // 2))).reshape(
+             1, 1, dim // 2
+         )
+
+     def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
+         if time.ndim == 2:
+             time = time[..., None]
+         t = self.inv_freq.to(time.device) * time
+         sin, cos = torch.sin(t), torch.cos(t)
+         sin = torch.stack((sin, sin), dim=-1).reshape(x.shape)
+         cos = torch.stack((cos, cos), dim=-1).reshape(x.shape)
+         flat_x = x.reshape(-1, x.shape[-1])
+         x1 = flat_x[:, ::2]
+         x2 = flat_x[:, 1::2]
+         return (x * cos) + (torch.stack((-x2, x1), dim=-1).reshape(x.shape) * sin)
+
+
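# Illustrative sketch (not part of the released file): rotary embeddings driven by the
# explicit `time` tensor the module expects rather than an internal integer index.
# Shapes assumed: x is (batch, seq_len, dim) with an even dim, positions is (batch, seq_len).
rope = RotaryPositionEmbedding(dim=64)
x = torch.randn(2, 10, 64)                          # e.g. query or key projections
positions = torch.arange(10).float().repeat(2, 1)   # could also hold continuous values
x_rotated = rope(x, positions)                      # same shape as x: (2, 10, 64)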
+ class GPT2AttentionRoPE(GPT2Attention):
+     """
+     GPT2AttentionRoPE inherits from `GPT2Attention`.
+
+     The primary change is in the forward pass, where rotary position embeddings
+     are optionally applied to the query and key projections before the standard
+     attention computation.
+     """
+
+     def __init__(
+         self, config, is_cross_attention=False, layer_idx=None, apply_rotary=False
+     ):
+         super().__init__(config, is_cross_attention, layer_idx)
+         self.apply_rotary = apply_rotary
+         if self.apply_rotary:
+             self.rope = RotaryPositionEmbedding(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         position_ids: Optional[Tuple[torch.FloatTensor]] = None,
+         layer_past: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+
+         if encoder_hidden_states is not None:
+             if not hasattr(self, "q_attn"):
+                 raise ValueError(
+                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                 )
+
+             query = self.q_attn(hidden_states)
+             key, value = self.c_attn(encoder_hidden_states).split(
+                 self.split_size, dim=2
+             )
+             attention_mask = encoder_attention_mask
+         else:
+             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+         if self.apply_rotary and position_ids is not None:
+             query = self.rope(query, position_ids)
+             key = self.rope(key, position_ids)
+             # value = self.rope(value, position_ids)
+
+         query = self._split_heads(query, self.num_heads, self.head_dim)
+         key = self._split_heads(key, self.num_heads, self.head_dim)
+         value = self._split_heads(value, self.num_heads, self.head_dim)
+
+         if layer_past is not None:
+             past_key, past_value = layer_past
+             key = torch.cat((past_key, key), dim=-2)
+             value = torch.cat((past_value, value), dim=-2)
+
+         if use_cache is True:
+             present = (key, value)
+         else:
+             present = None
+
+         if self.reorder_and_upcast_attn:
+             attn_output, attn_weights = self._upcast_and_reordered_attn(
+                 query, key, value, attention_mask, head_mask
+             )
+         else:
+             attn_output, attn_weights = self._attn(
+                 query, key, value, attention_mask, head_mask
+             )
+
+         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+         attn_output = self.c_proj(attn_output)
+         attn_output = self.resid_dropout(attn_output)
+
+         outputs = (attn_output, present)
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs  # a, present, (attentions)
+
+
+ class GPT2FlashAttention(GPT2Attention):
+     """
+     GPT2FlashAttention inherits from `GPT2Attention`.
+
+     The primary change is in the forward pass, where it correctly
+     calls the public API of flash attention and handles padding tokens.
+     """
+
+     def __init__(
+         self, config, is_cross_attention=False, layer_idx=None, apply_rotary=False
+     ):
+         super().__init__(config, is_cross_attention, layer_idx)
+         self.apply_rotary = apply_rotary
+         if self.apply_rotary:
+             self.rope = RotaryPositionEmbedding(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         position_ids: Optional[Tuple[torch.FloatTensor]] = None,
+         layer_past: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+
+         # Prepare query, key, and value
+         if encoder_hidden_states is not None:
+             if not hasattr(self, "q_attn"):
+                 raise ValueError(
+                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                 )
+
+             query = self.q_attn(hidden_states)
+             key, value = self.c_attn(encoder_hidden_states).split(
+                 self.split_size, dim=2
+             )
+             attention_mask = encoder_attention_mask
+         else:
+             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+         if self.apply_rotary and position_ids is not None:
+             query = self.rope(query, position_ids)
+             key = self.rope(key, position_ids)
+             # value = self.rope(value, position_ids)
+
+         query = self._split_heads(query, self.num_heads, self.head_dim)
+         key = self._split_heads(key, self.num_heads, self.head_dim)
+         value = self._split_heads(value, self.num_heads, self.head_dim)
+
+         if layer_past is not None:
+             past_key, past_value = layer_past
+             key = torch.cat((past_key, key), dim=-2)
+             value = torch.cat((past_value, value), dim=-2)
+
+         if use_cache is True:
+             present = (key, value)
+         else:
+             present = None
+
+         # Apply Flash Attention Forward
+         if self.reorder_and_upcast_attn:
+             attn_output, attn_weights = self._upcast_and_reordered_attn(
+                 query, key, value, attention_mask, head_mask
+             )
+         else:
+             # Flash Attention forward pass
+             attn_output = self._flash_attention_forward(
+                 query,
+                 key,
+                 value,
+                 attention_mask,
+                 query.size(-2),
+                 self.attn_dropout.p,
+                 softmax_scale=None,
+             )
+
+         # Merge heads and project back to hidden size
+         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+         attn_output = self.c_proj(attn_output)
+         attn_output = self.resid_dropout(attn_output)
+
+         outputs = (attn_output, present)
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+     def _flash_attention_forward(
+         self,
+         query_states,
+         key_states,
+         value_states,
+         attention_mask,
+         query_length,
+         dropout=0.0,
+         softmax_scale=None,
+     ):
+         """
+         Calls the Flash Attention forward API.
+
+         If the input contains at least one padding token, the input is first unpadded,
+         attention is computed over the packed sequences, and the output is padded back
+         to the original shape.
+
+         Args:
+             query_states (`torch.Tensor`):
+                 Input query states to be passed to the Flash Attention API
+             key_states (`torch.Tensor`):
+                 Input key states to be passed to the Flash Attention API
+             value_states (`torch.Tensor`):
+                 Input value states to be passed to the Flash Attention API
+             attention_mask (`torch.Tensor`):
+                 The padding mask - a tensor of size `(batch_size, seq_len)` where 0 marks the
+                 positions of padding tokens and 1 the positions of non-padding tokens.
+             query_length (`int`):
+                 The sequence length of the query states
+             dropout (`float`, *optional*):
+                 Attention dropout probability
+             softmax_scale (`float`, *optional*):
+                 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+         """
+
+         # Flash attention expects inputs of shape
+         # (batch_size, seq_length, num_heads, head_dim) in fp16/bf16, so permute from
+         # GPT-2's (batch_size, num_heads, seq_length, head_dim) layout and cast;
+         # the original dtype is restored on the way out.
+         dtype = query_states.dtype
+         query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
+         key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
+         value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
+
+         # Contains at least one padding token in the sequence
+         if attention_mask is not None:
+             batch_size = query_states.shape[0]
+
+             (
+                 query_states,
+                 key_states,
+                 value_states,
+                 indices_q,
+                 cu_seq_lens,
+                 max_seq_lens,
+             ) = self._upad_input(
+                 query_states, key_states, value_states, attention_mask, query_length
+             )
+
+             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+             attn_output_unpad = flash_attn_varlen_func(
+                 query_states,
+                 key_states,
+                 value_states,
+                 cu_seqlens_q=cu_seqlens_q,
+                 cu_seqlens_k=cu_seqlens_k,
+                 max_seqlen_q=max_seqlen_in_batch_q,
+                 max_seqlen_k=max_seqlen_in_batch_k,
+                 dropout_p=dropout,
+                 softmax_scale=softmax_scale,
+                 causal=True,
+             )
+             # (batch, seq_length, n_heads, head_dim)
+             attn_output = pad_input(
+                 attn_output_unpad, indices_q, batch_size, query_length
+             )
+         else:
+             attn_output = flash_attn_func(
+                 query_states,
+                 key_states,
+                 value_states,
+                 dropout,
+                 softmax_scale=softmax_scale,
+                 causal=self.is_causal,
+             )
+         # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
+         return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
+
+     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+     def _upad_input(
+         self, query_layer, key_layer, value_layer, attention_mask, query_length
+     ):
+         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+         key_layer = index_first_axis(
+             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+             indices_k,
+         )
+         value_layer = index_first_axis(
+             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+             indices_k,
+         )
+         if query_length == kv_seq_len:
+             query_layer = index_first_axis(
+                 query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
+                 indices_k,
+             )
+             cu_seqlens_q = cu_seqlens_k
+             max_seqlen_in_batch_q = max_seqlen_in_batch_k
+             indices_q = indices_k
+         elif query_length == 1:
+             max_seqlen_in_batch_q = 1
+             cu_seqlens_q = torch.arange(
+                 batch_size + 1, dtype=torch.int32, device=query_layer.device
+             )  # There is a memcpy here, that is very bad.
+             indices_q = cu_seqlens_q[:-1]
+             query_layer = query_layer.squeeze(1)
+         else:
+             # The -q_len: slice assumes left padding.
+             attention_mask = attention_mask[:, -query_length:]
+             query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                 query_layer, attention_mask
+             )
+
+         return (
+             query_layer,
+             key_layer,
+             value_layer,
+             indices_q,
+             (cu_seqlens_q, cu_seqlens_k),
+             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+         )
+
+
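# Illustrative sketch (not part of the released file): the layout change performed at
# the top of _flash_attention_forward. GPT-2's _split_heads yields
# (batch, n_heads, seq_len, head_dim); the flash-attn kernels expect
# (batch, seq_len, n_heads, head_dim) in fp16/bf16, and the cast is undone on return.
q_gpt2 = torch.randn(2, 12, 128, 64)  # (batch, n_heads, seq_len, head_dim)
q_flash = q_gpt2.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
assert q_flash.shape == (2, 128, 12, 64)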
+ class LlamaMLP(nn.Module):
+     def __init__(self, intermediate_size, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size
+         self.gate_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+         )
+         self.up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+         )
+         self.down_proj = nn.Linear(
+             self.intermediate_size, self.hidden_size, bias=config.mlp_bias
+         )
+         self.act_fn = ACT2FN[config.activation_function]
+
+     def forward(self, x: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ class GPT2MLP(nn.Module):
+     def __init__(self, intermediate_size, config):
+         super().__init__()
+         embed_dim = config.hidden_size
+         self.c_fc = Conv1D(intermediate_size, embed_dim)
+         self.c_proj = Conv1D(embed_dim, intermediate_size)
+         self.act = ACT2FN[config.activation_function]
+         self.dropout = nn.Dropout(config.resid_pdrop)
+
+     def forward(
+         self, hidden_states: Optional[Tuple[torch.FloatTensor]]
+     ) -> torch.FloatTensor:
+         hidden_states = self.c_fc(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.c_proj(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return hidden_states
+
+
+ class GPT2Block(nn.Module):
+     def __init__(self, config, layer_idx=None):
+         super().__init__()
+         hidden_size = config.hidden_size
+         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+         attention_class = (
+             GPT2FlashAttention
+             if getattr(config, "_attn_implementation", "eager") == "flash_attention_2"
+             else GPT2AttentionRoPE
+         )
+         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.attn = attention_class(
+             config=config, layer_idx=layer_idx, apply_rotary=config.apply_rotary
+         )
+         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+         if config.add_cross_attention:
+             self.crossattention = attention_class(
+                 config=config, is_cross_attention=True, layer_idx=layer_idx
+             )
+             self.ln_cross_attn = nn.LayerNorm(
+                 hidden_size, eps=config.layer_norm_epsilon
+             )
+
+         decoder_mlp_function = getattr(config, "decoder_mlp", "GPT2MLP")
+         if decoder_mlp_function == "GPT2MLP":
+             self.mlp = GPT2MLP(inner_dim, config)
+         elif decoder_mlp_function == "LlamaMLP":
+             self.mlp = LlamaMLP(inner_dim, config)
+         else:
+             raise RuntimeError("You must set decoder_mlp to one of (GPT2MLP, LlamaMLP)")
+
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         position_ids: Optional[Tuple[torch.FloatTensor]] = None,
+         layer_past: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+     ) -> Union[
+         Tuple[torch.Tensor],
+         Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
+     ]:
+         residual = hidden_states
+         hidden_states = self.ln_1(hidden_states)
+         attn_outputs = self.attn(
+             hidden_states,
+             position_ids=position_ids,
+             layer_past=layer_past,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+         )
+         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+         outputs = attn_outputs[1:]
+         # residual connection
+         hidden_states = attn_output + residual
+
+         if encoder_hidden_states is not None:
+             # add one self-attention block for cross-attention
+             if not hasattr(self, "crossattention"):
+                 raise ValueError(
+                     f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                     "cross-attention layers by setting `config.add_cross_attention=True`"
+                 )
+             residual = hidden_states
+             hidden_states = self.ln_cross_attn(hidden_states)
+             cross_attn_outputs = self.crossattention(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 output_attentions=output_attentions,
+             )
+             attn_output = cross_attn_outputs[0]
+             # residual connection
+             hidden_states = residual + attn_output
+             outputs = (
+                 outputs + cross_attn_outputs[2:]
+             )  # add cross attentions if we output attention weights
+
+         residual = hidden_states
+         hidden_states = self.ln_2(hidden_states)
+         feed_forward_hidden_states = self.mlp(hidden_states)
+         # residual connection
+         hidden_states = residual + feed_forward_hidden_states
+
+         if use_cache:
+             outputs = (hidden_states,) + outputs
+         else:
+             outputs = (hidden_states,) + outputs[1:]
+
+         return outputs  # hidden_states, present, (attentions, cross_attentions)
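
The block above is wired together through config attributes that a stock GPT2Config does not define (`apply_rotary`, `decoder_mlp`, and `mlp_bias` when LlamaMLP is selected), plus the usual `_attn_implementation` switch that routes to GPT2FlashAttention. A minimal sketch of instantiating the block directly, assuming cehrgpt 0.1.3 is installed alongside a transformers release whose GPT2Attention still exposes `_split_heads`/`_attn` (the methods this module relies on):

import torch
from transformers import GPT2Config

from cehrgpt.models.gpt2 import GPT2Block

config = GPT2Config(n_embd=128, n_head=4, n_layer=2)
config.apply_rotary = True      # route queries/keys through RotaryPositionEmbedding
config.decoder_mlp = "GPT2MLP"  # "LlamaMLP" additionally requires config.mlp_bias
# config._attn_implementation = "flash_attention_2" would select GPT2FlashAttention

block = GPT2Block(config, layer_idx=0)
hidden_states = torch.randn(2, 16, config.hidden_size)
position_ids = torch.arange(16).float().repeat(2, 1)  # positions fed to the rotary embedding
hidden_out = block(hidden_states, position_ids=position_ids)[0]  # shape (2, 16, 128)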