cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/models/gpt2.py
ADDED
@@ -0,0 +1,560 @@
from typing import Optional, Tuple, Union

import torch.nn.functional as f
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Attention
from transformers.utils import is_flash_attn_2_available, logging

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

from cehrgpt.models.activations import RMSNorm

logger = logging.get_logger("transformers")


def is_sample_pack(attention_mask: torch.Tensor) -> bool:
    """
    Determines whether any sequence in the batch is likely sample-packed.

    A sample-packed sequence is one where there are non-padding (1) tokens
    after a padding (0) token, indicating multiple sequences packed together
    with padding as a separator.

    Args:
        attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
            where 1 indicates a real token and 0 indicates padding.

    Returns:
        bool: True if any sample in the batch is sample-packed, False otherwise.
    """

    # If the attention_mask is left-padded, flip it so we can use the same logic below
    if (attention_mask[:, 0] == 0).any():
        attention_mask = attention_mask.flip(dims=[1])

    nonzero_counts = attention_mask.sum(dim=1)
    max_token_positions = torch.argmax(
        attention_mask.to(torch.int32).flip(dims=[1]), dim=1
    )
    max_indices = attention_mask.shape[1] - 1 - max_token_positions
    return torch.any(nonzero_counts < (max_indices + 1)).item()


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    # This infers sample packing
    if is_sample_pack(attention_mask):
        # Assume input: attention_mask shape = (batch, seq_len)
        attention_mask = attention_mask.flatten()  # shape: (seq_len,)

        # Compute max_index of the last non-zero element
        nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
        max_index = nonzero[-1].item()

        # Pad the truncated attention mask
        padded_attention_mask = f.pad(attention_mask[: max_index + 1], (0, 1), value=0)

        # Indices of all tokens
        indices = torch.nonzero(attention_mask, as_tuple=False).flatten()

        # Find where 0s occur (segment boundaries)
        cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
            padded_attention_mask == 0
        ]

        # Compute seqlens per segment
        seqlens_in_batch = (
            cumsum_seqlens_in_batch
            - f.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
        ).to(torch.int)

        max_seqlen_in_batch = (
            seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
        )
        cu_seqlens = f.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
    else:
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = f.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )

    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.inv_freq = 1.0 / (10000 ** (torch.linspace(0, 2, steps=dim // 2))).reshape(
            1, 1, dim // 2
        )

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        if time.ndim == 2:
            time = time[..., None]
        t = self.inv_freq.to(time.device) * time
        sin, cos = torch.sin(t), torch.cos(t)
        sin = torch.stack((sin, sin), dim=-1).reshape(x.shape)
        cos = torch.stack((cos, cos), dim=-1).reshape(x.shape)
        flat_x = x.reshape(-1, x.shape[-1])
        x1 = flat_x[:, ::2]
        x2 = flat_x[:, 1::2]
        return (x * cos) + (torch.stack((-x2, x1), dim=-1).reshape(x.shape) * sin)


class GPT2AttentionRoPE(GPT2Attention):
    """
    GPT2AttentionRoPE inherits from `GPT2Attention`.

    The primary change is in the forward pass, where it optionally applies
    rotary position embeddings to the query and key projections before
    computing attention.
    """

    def __init__(
        self, config, is_cross_attention=False, layer_idx=None, apply_rotary=False
    ):
        super().__init__(config, is_cross_attention, layer_idx)
        self.apply_rotary = apply_rotary
        if self.apply_rotary:
            self.rope = RotaryPositionEmbedding(config.hidden_size)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        position_ids: Optional[Tuple[torch.FloatTensor]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:

        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2
            )
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        if self.apply_rotary and position_ids is not None:
            query = self.rope(query, position_ids)
            key = self.rope(key, position_ids)
            # value = self.rope(value, position_ids)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask
            )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class GPT2FlashAttention(GPT2Attention):
    """
    GPT2FlashAttention inherits from `GPT2Attention`.

    The primary change is in the forward pass, where it correctly
    calls the public API of flash attention and handles padding tokens.
    """

    def __init__(
        self, config, is_cross_attention=False, layer_idx=None, apply_rotary=False
    ):
        super().__init__(config, is_cross_attention, layer_idx)
        self.apply_rotary = apply_rotary
        if self.apply_rotary:
            self.rope = RotaryPositionEmbedding(config.hidden_size)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        position_ids: Optional[Tuple[torch.FloatTensor]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:

        # Prepare query, key, and value
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2
            )
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        if self.apply_rotary and position_ids is not None:
            query = self.rope(query, position_ids)
            key = self.rope(key, position_ids)
            # value = self.rope(value, position_ids)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Apply Flash Attention Forward
        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask
            )
        else:
            # Flash Attention forward pass
            attn_output = self._flash_attention_forward(
                query,
                key,
                value,
                attention_mask,
                query.size(-2),
                self.attn_dropout.p,
                softmax_scale=None,
            )

        # Merge heads and project back to hidden size
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention.

        If the input hidden states contain at least one padding token, first unpad
        the input, then compute the attention scores, and finally pad the attention
        scores back to the original shape.
        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x n_heads x head_dim,
        # so permute from (batch, n_heads, seq_length, head_dim)
        dtype = query_states.dtype
        query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
        key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
        value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]

            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=True,
            )
            # (batch, seq_length, n_heads, head_dim)
            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )
        # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
        return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


class LlamaMLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=config.mlp_bias
        )
        self.act_fn = ACT2FN[config.activation_function]

    def forward(self, x: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class GPT2MLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        attention_class = (
            GPT2FlashAttention
            if getattr(config, "_attn_implementation", "eager") == "flash_attention_2"
            else GPT2AttentionRoPE
        )
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = attention_class(
            config=config, layer_idx=layer_idx, apply_rotary=config.apply_rotary
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = attention_class(
                config=config, is_cross_attention=True, layer_idx=layer_idx
            )
            self.ln_cross_attn = nn.LayerNorm(
                hidden_size, eps=config.layer_norm_epsilon
            )

        decoder_mlp_function = getattr(config, "decoder_mlp", "GPT2MLP")
        if decoder_mlp_function == "GPT2MLP":
            self.mlp = GPT2MLP(inner_dim, config)
        elif decoder_mlp_function == "LlamaMLP":
            self.mlp = LlamaMLP(inner_dim, config)
        else:
            raise RuntimeError("You must set decoder_mlp to one of (GPT2MLP, LlamaMLP)")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        position_ids: Optional[Tuple[torch.FloatTensor]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            position_ids=position_ids,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = (
                outputs + cross_attn_outputs[2:]
            )  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)
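
As a quick illustration of what the new sample-packing helpers return, here is a minimal sketch that is not part of the wheel; it assumes torch and the package's transformers dependency are installed (flash-attn is optional, since its import is guarded) and imports the private `_get_unpad_data` purely for demonstration:

import torch

from cehrgpt.models.gpt2 import _get_unpad_data, is_sample_pack

# One row packing two sequences of lengths 2 and 3, separated/terminated by zeros.
packed_mask = torch.tensor([[1, 1, 0, 1, 1, 1, 0, 0]])
# An ordinary right-padded batch for comparison.
padded_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

print(is_sample_pack(packed_mask))  # True: real tokens appear after a padding token
print(is_sample_pack(padded_mask))  # False: padding only trails each sequence

indices, cu_seqlens, max_seqlen = _get_unpad_data(packed_mask)
print(indices)     # tensor([0, 1, 3, 4, 5]) -> positions of the non-padding tokens
print(cu_seqlens)  # tensor([0, 2, 5], dtype=torch.int32) -> cumulative segment lengths
print(max_seqlen)  # 3 -> longest packed segment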
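
And a sketch of how the configurable pieces of `GPT2Block` fit together. This is illustrative only: `apply_rotary`, `decoder_mlp`, and `mlp_bias` are the extra attributes this module reads off the model config (normally supplied by the package's own config class), so they are set by hand here on a stock `GPT2Config` with hypothetical small hyperparameters:

from transformers import GPT2Config

from cehrgpt.models.gpt2 import GPT2AttentionRoPE, GPT2Block, LlamaMLP

config = GPT2Config(n_embd=128, n_head=4, n_layer=2)
config.apply_rotary = True       # attention applies rotary embeddings to Q and K
config.decoder_mlp = "LlamaMLP"  # gated Llama-style MLP instead of the default GPT2MLP
config.mlp_bias = False          # read by LlamaMLP for its nn.Linear layers

block = GPT2Block(config, layer_idx=0)
print(isinstance(block.attn, GPT2AttentionRoPE))  # True; GPT2FlashAttention is chosen only under flash_attention_2
print(isinstance(block.mlp, LlamaMLP))            # True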