cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -6,9 +6,8 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import
+from torch.distributions import Gamma
 from torch.nn import CrossEntropyLoss
-from torch.nn import functional as F
 from transformers import PreTrainedModel
 from transformers.activations import gelu_new
 from transformers.generation.logits_process import LogitsProcessorList
@@ -18,16 +17,20 @@ from transformers.generation.stopping_criteria import (
 )
 from transformers.generation.streamers import BaseStreamer
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
 from transformers.pytorch_utils import Conv1D
-from transformers.utils import
-    is_accelerate_available,
-    is_flash_attn_2_available,
-    logging,
-)
+from transformers.utils import is_flash_attn_2_available, logging
 from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 
+from cehrgpt.gpt_utils import (
+    construct_age_sequence,
+    encode_demographics,
+    extract_time_interval_in_days,
+    is_att_token,
+    multiple_of_10,
+)
+from cehrgpt.models.activations import RMSNorm
 from cehrgpt.models.config import CEHRGPTConfig
+from cehrgpt.models.gpt2 import GPT2Block, is_sample_pack
 from cehrgpt.models.hf_modeling_outputs import (
     CehrGptCausalLMOutput,
     CehrGptGenerateDecoderOnlyOutput,
@@ -35,13 +38,6 @@ from cehrgpt.models.hf_modeling_outputs import (
     CehrGptSequenceClassifierOutput,
 )
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
-
-if is_accelerate_available():
-    from accelerate.hooks import add_hook_to_module
-
 logger = logging.get_logger(__name__)
 
 
@@ -50,7 +46,7 @@ def extract_features_from_packed_sequence(
     attention_mask: torch.Tensor,
 ) -> torch.Tensor:
     max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
-    padded_attention_mask =
+    padded_attention_mask = f.pad(attention_mask[:, : max_index + 1], (0, 1))
     feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
     return hidden_state[:, feature_indices]
 
@@ -81,313 +77,41 @@ def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.
     return attn_matrix
 
 
-
-
-    Determines whether any sequence in the batch is likely sample-packed.
-
-    A sample-packed sequence is one where there are non-padding (1) tokens
-    after a padding (0) token, indicating multiple sequences packed together
-    with padding as a separator.
-
-    Args:
-        attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
-            where 1 indicates a real token and 0 indicates padding.
-
-    Returns:
-        bool: True if any sample in the batch is sample-packed, False otherwise.
-    """
-
-    # If the attention_maks is left padded, we will flip it so we can use the same logic below
-    if (attention_mask[:, 0] == 0).any():
-        attention_mask = attention_mask.flip(dims=[1])
-
-    nonzero_counts = attention_mask.sum(dim=1)
-    max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
-    max_indices = attention_mask.shape[1] - 1 - max_token_positions
-    return torch.any(nonzero_counts < (max_indices + 1)).item()
-
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
-    # This infers sample packing
-    if is_sample_pack(attention_mask):
-        # Assume input: attention_mask shape = (batch, seq_len)
-        attention_mask = attention_mask.flatten()  # shape: (seq_len,)
-
-        # Compute max_index of the last non-zero element
-        nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
-        max_index = nonzero[-1].item()
-
-        # Pad the truncated attention mask
-        padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
-
-        # Indices of all tokens
-        indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
-
-        # Find where 0s occur (segment boundaries)
-        cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
-            padded_attention_mask == 0
-        ]
-
-        # Compute seqlens per segment
-        seqlens_in_batch = (
-            cumsum_seqlens_in_batch
-            - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
-        ).to(torch.int)
-
-        max_seqlen_in_batch = (
-            seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
-        )
-        cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
-    else:
-        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-        max_seqlen_in_batch = seqlens_in_batch.max().item()
-        cu_seqlens = F.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
-
-
-class GPT2FlashAttention(GPT2Attention):
-    """
-    GPT2FlashAttention inherits from `GPT2Attention`.
-
-    The primary change is in the forward pass, where it correctly
-    calls the public API of flash attention and handles padding tokens.
-    """
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        # Prepare query, key, and value
-        if encoder_hidden_states is not None:
-            if not hasattr(self, "q_attn"):
-                raise ValueError(
-                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                )
-
-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(
-                self.split_size, dim=2
-            )
-            attention_mask = encoder_attention_mask
-        else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
-        if use_cache is True:
-            present = (key, value)
-        else:
-            present = None
-
-        # Apply Flash Attention Forward
-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(
-                query, key, value, attention_mask, head_mask
-            )
-        else:
-            # Flash Attention forward pass
-            attn_output = self._flash_attention_forward(
-                query,
-                key,
-                value,
-                attention_mask,
-                query.size(-2),
-                self.attn_dropout.p,
-                softmax_scale=None,
-            )
-
-        # Merge heads and project back to hidden size
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
-
-    def _flash_attention_forward(
+class MotorTaskHead(nn.Module):
+    def __init__(
         self,
-
-
-
-
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
+        input_dim,
+        motor_tte_vocab_size,
+        motor_num_time_pieces,
+        eps=1e-6,
     ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token.
-
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        dtype = query_states.dtype
-        query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-        key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-        value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-
-            (
-                query_states,
-                key_states,
-                value_states,
-                indices_q,
-                cu_seq_lens,
-                max_seq_lens,
-            ) = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-            # (batch, seq_length, n_heads, head_dim)
-            attn_output = pad_input(
-                attn_output_unpad, indices_q, batch_size, query_length
-            )
-        else:
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                dropout,
-                softmax_scale=softmax_scale,
-                causal=self.is_causal,
-            )
-        # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-        return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
-    def _upad_input(
-        self, query_layer, key_layer, value_layer, attention_mask, query_length
-    ):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            indices_k,
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            indices_k,
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
-                indices_k,
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
-                query_layer, attention_mask
-            )
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
-class MotorTaskHead(nn.Module):
-    def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
         super(MotorTaskHead, self).__init__()
         self.input_dim = input_dim
         self.motor_tte_vocab_size = motor_tte_vocab_size
         self.motor_num_time_pieces = motor_num_time_pieces
-        self.
-
-
-
-
-        ),
+        self.final_layer = nn.Linear(input_dim, input_dim * motor_num_time_pieces)
+        self.norm = RMSNorm(input_dim, eps)
+        self.task_layer = nn.Linear(input_dim, motor_tte_vocab_size)
+        self.task_time_bias = nn.Parameter(
+            torch.zeros(1, self.motor_num_time_pieces, motor_tte_vocab_size)
         )
 
     def forward(self, x):
         # Ensure scale is positive
         length = x.shape[0]
         # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
-
-
-        if torch.isnan(lambda_p).any():
-            logger.warning(f"NaN values found in scale_param. x: {x}")
-        # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
-        return lambda_p.view(
-            length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+        x = self.final_layer(x).reshape(
+            length, self.motor_num_time_pieces, self.input_dim
         )
+        x = self.norm(x)
+        x = self.task_layer(x) + self.task_time_bias
+        # lambda_p = f.softplus(x)
+
+        # # Check for NaN values
+        # if torch.isnan(lambda_p).any():
+        # logger.warning(f"NaN values found in scale_param. x: {x}")
+        # # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+        return x
 
 
 class VisitTimeToEventHead(nn.Module):
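The new MotorTaskHead replaces the removed flash-attention plumbing with a piecewise time-to-event head: a linear layer expands each feature vector into one vector per time piece, RMSNorm normalizes them, and a shared task layer plus a per-piece bias produces logits of shape (N, motor_num_time_pieces, motor_tte_vocab_size). Below is a minimal standalone sketch of that shape flow; it uses a simple RMSNorm stand-in and made-up sizes rather than the package's own RMSNorm class or config values.

import torch
import torch.nn as nn


class _RMSNormStandIn(nn.Module):
    # Stand-in for cehrgpt.models.activations.RMSNorm, only to keep this sketch self-contained.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


input_dim, num_pieces, tte_vocab = 32, 8, 100  # hypothetical sizes
final_layer = nn.Linear(input_dim, input_dim * num_pieces)
norm = _RMSNormStandIn(input_dim)
task_layer = nn.Linear(input_dim, tte_vocab)
task_time_bias = nn.Parameter(torch.zeros(1, num_pieces, tte_vocab))

x = torch.randn(5, input_dim)                  # five task-token features
x = final_layer(x).reshape(5, num_pieces, input_dim)
x = norm(x)
logits = task_layer(x) + task_time_bias        # -> (5, num_pieces, tte_vocab)
print(logits.shape)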
@@ -723,7 +447,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def __init__(self, config: CEHRGPTConfig):
         super().__init__(config)
 
-        self.exclude_position_ids = config.exclude_position_ids
         self.include_values = config.include_values
         self.include_ttv_prediction = config.include_ttv_prediction
         self.embed_dim = config.hidden_size
@@ -734,8 +457,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.pretrained_wte = None
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
-        if not self.exclude_position_ids:
-            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         if self.include_values:
             self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
             self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -746,9 +467,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         gpt_blocks = []
         for i in range(config.num_hidden_layers):
             gpt_block = GPT2Block(config, layer_idx=i)
-
-                gpt_block.attn = GPT2FlashAttention(config, layer_idx=i)
-                gpt_block.is_causal = True
+            gpt_block.is_causal = True
             gpt_blocks.append(gpt_block)
         self.h = nn.ModuleList(gpt_blocks)
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -769,10 +488,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         )
         self.update_attn_bias(self.config.sample_packing_max_positions)
 
-    def enable_position_embeddings(self):
-        self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
-        self.config.exclude_position_ids = False
-
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -815,8 +530,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to(self.first_device)
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to(self.first_device)
-        if not self.exclude_position_ids:
-            self.wpe = self.wpe.to(self.first_device)
         if self.include_values:
             self.vte = self.vte.to(self.first_device)
             self.concept_value_transformation_layer = (
@@ -842,8 +555,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to("cpu")
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to("cpu")
-        if not self.exclude_position_ids:
-            self.wpe = self.wpe.to("cpu")
         self.vte = self.vte.to("cpu")
         self.concept_value_transformation_layer = (
             self.concept_value_transformation_layer.to("cpu")
@@ -871,8 +582,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def get_position_embeddings(
         self,
     ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
-        if not self.exclude_position_ids:
-            return self.wpe
         return None
 
     def set_position_embeddings(self, new_embeddings: nn.Embedding):
@@ -946,24 +655,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         # Convert list back to torch.Size if needed
         input_shape = torch.Size(shape_list)
 
-
+        input_ids.device
 
         if past_key_values is None:
-            past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-
-
-        # This is normally called during training or fine-tuning.
-        # While the generation logic will handle position_ids in the sampling logic
-        if position_ids is None and not self.exclude_position_ids:
-            position_ids = torch.arange(
-                past_length,
-                input_shape[-1] + past_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0)
+            past_key_values[0][0].size(-2)
 
         # GPT2Attention mask.
         if attention_mask is not None:
@@ -1046,10 +743,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             ]
             if random_vectors is None:
                 random_vectors = torch.rand_like(input_embeddings[:, :1])
+
             input_embeddings = torch.concat(
                 [demographic_embeddings, random_vectors, medical_event_embeddings],
                 dim=1,
             )
+            position_ids = torch.concat(
+                [
+                    position_ids[:, : self.config.demographics_size],
+                    position_ids[:, :1],
+                    position_ids[:, self.config.demographics_size :],
+                ],
+                dim=1,
+            )
 
         if self.include_values:
             if (
@@ -1075,13 +781,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 value_embeddings=value_embeddings,
             )
 
-
-            position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
-            hidden_states = input_embeddings + position_embeds
-        else:
-            hidden_states = input_embeddings
-
-        hidden_states = self.drop(hidden_states)
+        hidden_states = self.drop(input_embeddings)
 
         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
 
@@ -1109,6 +809,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                     attention_mask = attention_mask.to(hidden_states.device)
                 if isinstance(head_mask, torch.Tensor):
                     head_mask = head_mask.to(hidden_states.device)
+
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -1116,6 +817,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
+                    position_ids,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -1127,6 +829,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             else:
                 outputs = block(
                     hidden_states,
+                    position_ids=position_ids,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
@@ -1200,7 +903,9 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
         if self.config.include_motor_time_to_event:
             self.motor_tte = MotorTaskHead(
-                config.n_embd,
+                input_dim=config.n_embd,
+                motor_tte_vocab_size=config.motor_tte_vocab_size,
+                motor_num_time_pieces=config.motor_num_time_pieces,
             )
 
         # Model parallel
@@ -1300,12 +1005,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids,
+        cehrgpt_tokenizer,
         past_key_values=None,
         inputs_embeds=None,
-        lab_token_ids=None,
         **kwargs,
     ):
-
+        ages = kwargs.get("ages")
         # Omit tokens covered by past_key_values
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
@@ -1320,33 +1025,10 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 remove_prefix_length = input_ids.shape[1] - 1
 
             input_ids = input_ids[:, remove_prefix_length:]
+            ages = ages[:, remove_prefix_length:]
 
         attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
         random_vectors = kwargs.get("random_vectors", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-            # Add one more position for the random vectors
-            if (
-                self.cehrgpt.config.causal_sfm
-                and position_ids.shape[-1] >= self.cehrgpt.config.demographics_size
-            ):
-                position_ids = torch.concat(
-                    [
-                        position_ids,
-                        torch.max(position_ids, dim=-1, keepdim=True)[0] + 1,
-                    ],
-                    dim=-1,
-                )
-        else:
-            position_ids = None
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1384,7 +1066,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
-                "
+                "ages": ages,
                 "attention_mask": attention_mask,
                 "random_vectors": random_vectors,
             }
@@ -1394,12 +1076,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
     def motor_nll_loss(
         self,
-
-
-
-
-
-
+        hidden_states,
+        motor_tte_times,
+        motor_tte_event_indicators,
+        motor_tte_task_indicators,
+        motor_tte_masks,
+        motor_end_index,
     ):
         """
         Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
@@ -1407,58 +1089,62 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         for modeling time-to-event data at each visit.
 
         Args:
-
-
-
-
-
-
-
+            hidden_states (Tensor): Hidden representations for sequence tokens [num_of_concepts, hidden_dim].
+            motor_tte_times (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+            motor_tte_task_indicators: (Tensor): Bool indicators (True if included, False if not included).
+            motor_tte_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+            motor_tte_masks (Tensor): Binary indicators whether the prediction should be masked
+                (1 if not masked, 0 if masked).
+            motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
 
         Returns:
             Tensor: Scalar loss value (mean negative log-likelihood).
         """
-
-
-            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-        )[:batch_motor_end_index].clamp(min=1e-3)
-        motor_event_indicators = motor_event_indicators.reshape(
+        motor_end_index = motor_end_index.sum().item()
+        motor_tte_times = motor_tte_times.view(
             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-        )[:
-
-            :batch_motor_end_index
-        ]
-        motor_time_indicators = motor_time_indicators.view(
+        )[:motor_end_index].clamp(min=1e-3)
+        motor_tte_event_indicators = motor_tte_event_indicators.reshape(
             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-        )[:
-
+        )[:motor_end_index]
+        # motor_tte_masks = motor_tte_masks.view(
+        # (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        # )[:motor_end_index]
+
+        tte_features = hidden_states[motor_tte_task_indicators].view(
+            (-1, self.config.n_embd)
+        )
+
+        assert tte_features.shape[0] == motor_tte_times.shape[0], (
             "The number of VE tokens in the labels needs to match up "
             "with the first dimension of motor_time_to_event_vectors. "
-            f"Received ve_token_features.shape[0]: {
-            f"motor_time_to_event_vectors.shape[0]: {
+            f"Received ve_token_features.shape[0]: {tte_features.shape[0]}, "
+            f"motor_time_to_event_vectors.shape[0]: {motor_tte_times.shape[0]}"
         )
-        motor_time_to_event_vectors = motor_time_to_event_vectors[
-            motor_time_to_event_to_include
-        ]
-        motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
-        motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
-        ve_token_features = ve_token_features[motor_time_to_event_to_include]
 
         # Get Exponential parameters from model
-
-        # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
-        dist = Exponential(lambda_p.clamp(min=1e-3))
+        time_dependent_logits = self.motor_tte(tte_features)
 
         # Compute event loss
-
-
-
-
-
-        )
+        # Calculate the accumulative hazard
+        # exp(-sum_{j} lambda_j)
+        survival_loss = torch.exp2(time_dependent_logits + motor_tte_times).mean()
+        event_loss = (
+            -math.log(2)
+            * torch.where(motor_tte_event_indicators, time_dependent_logits, 0).mean()
        )
-
-
+
+        # survival_loss = (
+        # torch.where(motor_tte_masks, lambda_p * motor_tte_times, 0)
+        # .sum(dim=1)
+        # .mean()
+        # )
+        # event_loss = (
+        # -torch.where(motor_tte_event_indicators, torch.log(lambda_p), 0)
+        # .sum(dim=1)
+        # .mean()
+        # )
+        return survival_loss + event_loss
 
     def forward(
         self,
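The rewritten motor_nll_loss treats the head's outputs as base-2 log hazards per time piece: the survival term averages 2^(logit + motor_tte_times), which equals hazard x duration under the assumption that motor_tte_times is already stored in log2 space, and the event term is -ln(2) times the mean logit at observed events. A hedged numeric sketch of those two terms with made-up shapes, assuming that log2-time convention:

import math
import torch

# (tokens, time pieces, tte vocab) with arbitrary sizes; `times` holds log2(duration)
# so that 2 ** (logit + time) == hazard * duration with hazard = 2 ** logit.
logits = torch.randn(4, 8, 100)
times = torch.log2(torch.rand(4, 8, 100) + 1e-3)
events = torch.rand(4, 8, 100) > 0.95             # event indicators

survival_loss = torch.exp2(logits + times).mean()  # mean of hazard * duration
event_loss = -math.log(2) * torch.where(events, logits, 0).mean()
loss = survival_loss + event_loss
print(loss.item())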
@@ -1467,7 +1153,6 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         random_vectors: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -1476,21 +1161,23 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_to_visits: Optional[torch.FloatTensor] = None,
         time_token_indicators: Optional[torch.BoolTensor] = None,
         sub_time_tokens: Optional[torch.LongTensor] = None,
-
-
-
-
+        motor_tte_times: Optional[torch.FloatTensor] = None,
+        motor_tte_event_indicators: Optional[torch.BoolTensor] = None,
+        motor_tte_task_indicators: Optional[torch.BoolTensor] = None,
+        motor_tte_masks: Optional[torch.BoolTensor] = None,
         motor_end_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        ages: Optional[torch.FloatTensor] = None,
+        epoch_times: Optional[torch.FloatTensor] = None,
     ) -> Union[Tuple, CehrGptCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-
-
-
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -1502,7 +1189,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             values=values,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=
+            position_ids=ages,
             random_vectors=random_vectors,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -1611,23 +1298,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
             if (
                 self.config.include_motor_time_to_event
-                and
-                and
-                and
-                and
+                and motor_tte_times is not None
+                and motor_tte_event_indicators is not None
+                and motor_tte_task_indicators is not None
+                and motor_tte_masks is not None
                 and motor_end_index is not None
             ):
-                ve_token_id_indices = labels == self.config.ve_token_id
-                ve_token_features = hidden_states[ve_token_id_indices]
-                # Get rid of the last VE features because it's already reached the end of the patient sequence and
-                # there is nothing to predict.
                 motor_tte_loss = self.motor_nll_loss(
-
-
-
-
-
-
+                    hidden_states=hidden_states,
+                    motor_tte_times=motor_tte_times,
+                    motor_tte_event_indicators=motor_tte_event_indicators,
+                    motor_tte_task_indicators=motor_tte_task_indicators,
+                    motor_tte_masks=motor_tte_masks,
+                    motor_end_index=motor_end_index,
                 )
                 loss += motor_tte_loss * self.config.motor_time_to_event_weight
 
@@ -1833,6 +1516,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             else self.generation_config.return_dict_in_generate
         )
 
+        if "cehrgpt_tokenizer" not in model_kwargs:
+            raise RuntimeError(
+                "The cehr-gpt tokenizer must be provided to the "
+                "model.generate(..., cehrgpt_tokenizer=cehrgpt_tokenizer)"
+            )
+
+        # Remove this from the model_kwargs and will pass it to other functions explicitly
+        cehrgpt_tokenizer = model_kwargs.pop("cehrgpt_tokenizer")
+
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
         raw_logits = () if (return_dict_in_generate and output_logits) else None
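With the check added in this hunk, sampling now fails fast unless the CEHR-GPT tokenizer is supplied as a keyword argument to generate; it is popped from model_kwargs and used to decode tokens during generation. A hedged usage sketch follows; the checkpoint path, the CehrGptTokenizer class name, and the prompt construction are assumptions, not taken from this diff.

import torch
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer  # assumed class name

# Hypothetical local checkpoint; from_pretrained is the standard Hugging Face loader.
model = CEHRGPT2LMHeadModel.from_pretrained("path/to/cehrgpt-checkpoint")
tokenizer = CehrGptTokenizer.from_pretrained("path/to/cehrgpt-checkpoint")

prompt = torch.tensor([[1, 2, 3, 4]])  # placeholder demographic prompt token ids
sequences = model.generate(
    prompt,
    cehrgpt_tokenizer=tokenizer,  # required; generate() raises RuntimeError without it
    max_new_tokens=128,
)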
@@ -1848,6 +1540,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape
+        model_kwargs["attention_mask"] = input_ids != pad_token_id
         if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
         this_peer_finished = False
@@ -1855,22 +1548,23 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             batch_size, dtype=torch.long, device=input_ids.device
         )
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-        #
-
-        lab_token_ids
-
-
-
+        # Getting the lab token ids
+        lab_token_ids = torch.tensor(
+            cehrgpt_tokenizer.lab_token_ids,
+            dtype=torch.int32,
+        )
+        if model_kwargs.get("value_indicators", None) is not None:
+            value_indicators = model_kwargs.get("value_indicators")
         else:
-
-
+            value_indicators = torch.zeros_like(input_ids).to(torch.bool)
+
+        if model_kwargs.get("values", None) is not None:
+            values = model_kwargs.get("values")
+        else:
+            values = torch.zeros_like(
+                input_ids,
                 dtype=torch.int32,
             )
-        value_indicators = torch.zeros_like(input_ids).to(torch.bool)
-        values = torch.zeros_like(
-            input_ids,
-            dtype=torch.int32,
-        )
         # Generate initial random_vectors
         if self.cehrgpt.config.causal_sfm:
             model_kwargs["random_vectors"] = torch.rand(
@@ -1884,11 +1578,33 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             model_kwargs["random_vectors"] = None
         model_kwargs["value_indicators"] = value_indicators
         model_kwargs["values"] = values
+
+        # A variable to keep track of time and initialize it to zero
+        batched_time_delta = np.zeros((batch_size,), dtype=np.float32)
+        batched_ages = model_kwargs.get("ages", None)
+        if batched_ages is None:
+            batched_ages = []
+            for token_ids in input_ids.detach().cpu():
+                concept_ids = cehrgpt_tokenizer.decode(
+                    token_ids.numpy(), skip_special_tokens=False
+                )
+                batched_ages.append(construct_age_sequence(concept_ids))
+            # Turn this to a numpy array for easy manipulation
+            batched_ages = np.asarray(batched_ages)
+        else:
+            batched_ages = batched_ages.cpu().numpy()
+        # This is the base to which we will add the time delta
+        base_ages = np.asarray([ages[-1] for ages in batched_ages])
+        # Update the keyword arguments for the prepare_inputs_for_generation
+        model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device
         ):
             # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, cehrgpt_tokenizer, **model_kwargs
+            )
 
             # forward pass to get next token
             outputs = self(
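The block above seeds base_ages from the prompt via construct_age_sequence and then, in the sampling loop shown next, adds the decoded day count of every ATT (artificial time) token to a running per-sequence delta; the age fed back through model_kwargs["ages"] is base_age + delta // 365. A small numpy illustration of that update rule with made-up values:

import numpy as np

base_ages = np.asarray([40, 67])                 # last age seen in each prompt
batched_time_delta = np.zeros(2, dtype=np.float32)

# Suppose sequence 0 samples an ATT token worth 400 days and sequence 1 a non-time token.
batched_time_delta[0] += 400.0

next_age = (base_ages + batched_time_delta // 365).astype(int)[..., None]
print(next_age.ravel())                          # -> [41 67]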
@@ -1933,6 +1649,22 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             probs = nn.functional.softmax(next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
+            # TODO: decode to get time tokens and recalculate the age at this time step
+            # Look for a potential time token
+            for batch_i, next_concept_id in enumerate(
+                cehrgpt_tokenizer.decode(
+                    next_tokens.detach().cpu().numpy(), skip_special_tokens=False
+                )
+            ):
+                if is_att_token(next_concept_id):
+                    batched_time_delta[batch_i] += extract_time_interval_in_days(
+                        next_concept_id
+                    )
+
+            next_age = (base_ages + batched_time_delta // 365).astype(int)[..., None]
+            batched_ages = np.concatenate([batched_ages, next_age], axis=-1)
+            model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
@@ -1968,6 +1700,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
@@ -2012,7 +1745,7 @@ class FocalLoss(nn.Module):
         self.reduction = reduction
 
     def forward(self, logits, targets):
-        bce_loss =
+        bce_loss = f.binary_cross_entropy_with_logits(logits, targets, reduction="none")
         probs = torch.sigmoid(logits)
         pt = torch.where(targets == 1, probs, 1 - probs)
         focal_term = (1 - pt) ** self.gamma
@@ -2094,12 +1827,13 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        ages: Optional[torch.FloatTensor] = None,
+        epoch_times: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CehrGptSequenceClassifierOutput:
 
@@ -2109,12 +1843,12 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
             values=values,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            position_ids=ages,
         )
 
         if is_sample_pack(attention_mask):