cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.4.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -6,9 +6,8 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import
+from torch.distributions import Gamma
 from torch.nn import CrossEntropyLoss
-from torch.nn import functional as F
 from transformers import PreTrainedModel
 from transformers.activations import gelu_new
 from transformers.generation.logits_process import LogitsProcessorList
@@ -18,16 +17,20 @@ from transformers.generation.stopping_criteria import (
 )
 from transformers.generation.streamers import BaseStreamer
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
 from transformers.pytorch_utils import Conv1D
-from transformers.utils import (
-    is_accelerate_available,
-    is_flash_attn_2_available,
-    logging,
-)
+from transformers.utils import is_flash_attn_2_available, logging
 from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 
+from cehrgpt.gpt_utils import (
+    construct_age_sequence,
+    encode_demographics,
+    extract_time_interval_in_days,
+    is_att_token,
+    multiple_of_10,
+)
+from cehrgpt.models.activations import RMSNorm
 from cehrgpt.models.config import CEHRGPTConfig
+from cehrgpt.models.gpt2 import GPT2Block, is_sample_pack
 from cehrgpt.models.hf_modeling_outputs import (
     CehrGptCausalLMOutput,
     CehrGptGenerateDecoderOnlyOutput,
@@ -35,13 +38,6 @@ from cehrgpt.models.hf_modeling_outputs import (
     CehrGptSequenceClassifierOutput,
 )
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
-if is_accelerate_available():
-    from accelerate.hooks import add_hook_to_module
-
 logger = logging.get_logger(__name__)
 
 
@@ -50,7 +46,7 @@ def extract_features_from_packed_sequence(
     attention_mask: torch.Tensor,
 ) -> torch.Tensor:
     max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
-    padded_attention_mask =
+    padded_attention_mask = f.pad(attention_mask[:, : max_index + 1], (0, 1))
     feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
     return hidden_state[:, feature_indices]
 
@@ -81,315 +77,41 @@ def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.
     return attn_matrix
 
 
-
-
-    Determines whether any sequence in the batch is likely sample-packed.
-
-    A sample-packed sequence is one where there are non-padding (1) tokens
-    after a padding (0) token, indicating multiple sequences packed together
-    with padding as a separator.
-
-    Args:
-        attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
-            where 1 indicates a real token and 0 indicates padding.
-
-    Returns:
-        bool: True if any sample in the batch is sample-packed, False otherwise.
-    """
-
-    # If the attention_maks is left padded, we will flip it so we can use the same logic below
-    if (attention_mask[:, 0] == 0).any():
-        attention_mask = attention_mask.flip(dims=[1])
-
-    nonzero_counts = attention_mask.sum(dim=1)
-    max_token_positions = torch.argmax(
-        attention_mask.to(torch.int32).flip(dims=[1]), dim=1
-    )
-    max_indices = attention_mask.shape[1] - 1 - max_token_positions
-    return torch.any(nonzero_counts < (max_indices + 1)).item()
-
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
-    # This infers sample packing
-    if is_sample_pack(attention_mask):
-        # Assume input: attention_mask shape = (batch, seq_len)
-        attention_mask = attention_mask.flatten()  # shape: (seq_len,)
-
-        # Compute max_index of the last non-zero element
-        nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
-        max_index = nonzero[-1].item()
-
-        # Pad the truncated attention mask
-        padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
-
-        # Indices of all tokens
-        indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
-
-        # Find where 0s occur (segment boundaries)
-        cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
-            padded_attention_mask == 0
-        ]
-
-        # Compute seqlens per segment
-        seqlens_in_batch = (
-            cumsum_seqlens_in_batch
-            - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
-        ).to(torch.int)
-
-        max_seqlen_in_batch = (
-            seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
-        )
-        cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
-    else:
-        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-        max_seqlen_in_batch = seqlens_in_batch.max().item()
-        cu_seqlens = F.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
-
-
-class GPT2FlashAttention(GPT2Attention):
-    """
-    GPT2FlashAttention inherits from `GPT2Attention`.
-
-    The primary change is in the forward pass, where it correctly
-    calls the public API of flash attention and handles padding tokens.
-    """
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        # Prepare query, key, and value
-        if encoder_hidden_states is not None:
-            if not hasattr(self, "q_attn"):
-                raise ValueError(
-                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                )
-
-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(
-                self.split_size, dim=2
-            )
-            attention_mask = encoder_attention_mask
-        else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
-        if use_cache is True:
-            present = (key, value)
-        else:
-            present = None
-
-        # Apply Flash Attention Forward
-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(
-                query, key, value, attention_mask, head_mask
-            )
-        else:
-            # Flash Attention forward pass
-            attn_output = self._flash_attention_forward(
-                query,
-                key,
-                value,
-                attention_mask,
-                query.size(-2),
-                self.attn_dropout.p,
-                softmax_scale=None,
-            )
-
-        # Merge heads and project back to hidden size
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
-
-    def _flash_attention_forward(
+class MotorTaskHead(nn.Module):
+    def __init__(
         self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
+        input_dim,
+        motor_tte_vocab_size,
+        motor_num_time_pieces,
+        eps=1e-6,
     ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token.
-
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        dtype = query_states.dtype
-        query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-        key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-        value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-
-            (
-                query_states,
-                key_states,
-                value_states,
-                indices_q,
-                cu_seq_lens,
-                max_seq_lens,
-            ) = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-            # (batch, seq_length, n_heads, head_dim)
-            attn_output = pad_input(
-                attn_output_unpad, indices_q, batch_size, query_length
-            )
-        else:
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                dropout,
-                softmax_scale=softmax_scale,
-                causal=self.is_causal,
-            )
-        # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-        return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
-    def _upad_input(
-        self, query_layer, key_layer, value_layer, attention_mask, query_length
-    ):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            indices_k,
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            indices_k,
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
-                indices_k,
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
-                query_layer, attention_mask
-            )
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
-class MotorTaskHead(nn.Module):
-    def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
         super(MotorTaskHead, self).__init__()
         self.input_dim = input_dim
         self.motor_tte_vocab_size = motor_tte_vocab_size
         self.motor_num_time_pieces = motor_num_time_pieces
-        self.
-
-
-
-
-        ),
+        self.final_layer = nn.Linear(input_dim, input_dim * motor_num_time_pieces)
+        self.norm = RMSNorm(input_dim, eps)
+        self.task_layer = nn.Linear(input_dim, motor_tte_vocab_size)
+        self.task_time_bias = nn.Parameter(
+            torch.zeros(1, self.motor_num_time_pieces, motor_tte_vocab_size)
         )
 
     def forward(self, x):
         # Ensure scale is positive
         length = x.shape[0]
         # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
-
-
-        if torch.isnan(lambda_p).any():
-            logger.warning(f"NaN values found in scale_param. x: {x}")
-        # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
-        return lambda_p.view(
-            length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+        x = self.final_layer(x).reshape(
+            length, self.motor_num_time_pieces, self.input_dim
         )
+        x = self.norm(x)
+        x = self.task_layer(x) + self.task_time_bias
+        # lambda_p = f.softplus(x)
+
+        # # Check for NaN values
+        # if torch.isnan(lambda_p).any():
+        #     logger.warning(f"NaN values found in scale_param. x: {x}")
+        # # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+        return x
 
 
 class VisitTimeToEventHead(nn.Module):
@@ -725,7 +447,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def __init__(self, config: CEHRGPTConfig):
         super().__init__(config)
 
-        self.exclude_position_ids = config.exclude_position_ids
         self.include_values = config.include_values
         self.include_ttv_prediction = config.include_ttv_prediction
         self.embed_dim = config.hidden_size
@@ -736,8 +457,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.pretrained_wte = None
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
-        if not self.exclude_position_ids:
-            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         if self.include_values:
             self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
             self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -748,9 +467,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         gpt_blocks = []
         for i in range(config.num_hidden_layers):
             gpt_block = GPT2Block(config, layer_idx=i)
-
-                gpt_block.attn = GPT2FlashAttention(config, layer_idx=i)
-                gpt_block.is_causal = True
+            gpt_block.is_causal = True
             gpt_blocks.append(gpt_block)
         self.h = nn.ModuleList(gpt_blocks)
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -771,10 +488,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         )
         self.update_attn_bias(self.config.sample_packing_max_positions)
 
-    def enable_position_embeddings(self):
-        self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
-        self.config.exclude_position_ids = False
-
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -817,8 +530,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to(self.first_device)
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to(self.first_device)
-        if not self.exclude_position_ids:
-            self.wpe = self.wpe.to(self.first_device)
         if self.include_values:
             self.vte = self.vte.to(self.first_device)
             self.concept_value_transformation_layer = (
@@ -844,8 +555,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to("cpu")
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to("cpu")
-        if not self.exclude_position_ids:
-            self.wpe = self.wpe.to("cpu")
         self.vte = self.vte.to("cpu")
         self.concept_value_transformation_layer = (
             self.concept_value_transformation_layer.to("cpu")
@@ -873,8 +582,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def get_position_embeddings(
         self,
     ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
-        if not self.exclude_position_ids:
-            return self.wpe
         return None
 
     def set_position_embeddings(self, new_embeddings: nn.Embedding):
@@ -948,24 +655,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         # Convert list back to torch.Size if needed
         input_shape = torch.Size(shape_list)
 
-
+        input_ids.device
 
         if past_key_values is None:
-            past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-
-
-        # This is normally called during training or fine-tuning.
-        # While the generation logic will handle position_ids in the sampling logic
-        if position_ids is None and not self.exclude_position_ids:
-            position_ids = torch.arange(
-                past_length,
-                input_shape[-1] + past_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0)
+            past_key_values[0][0].size(-2)
 
         # GPT2Attention mask.
         if attention_mask is not None:
@@ -1048,10 +743,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             ]
             if random_vectors is None:
                 random_vectors = torch.rand_like(input_embeddings[:, :1])
+
             input_embeddings = torch.concat(
                 [demographic_embeddings, random_vectors, medical_event_embeddings],
                 dim=1,
             )
+            position_ids = torch.concat(
+                [
+                    position_ids[:, : self.config.demographics_size],
+                    position_ids[:, :1],
+                    position_ids[:, self.config.demographics_size :],
+                ],
+                dim=1,
+            )
 
         if self.include_values:
             if (
@@ -1077,13 +781,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 value_embeddings=value_embeddings,
             )
 
-
-            position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
-            hidden_states = input_embeddings + position_embeds
-        else:
-            hidden_states = input_embeddings
-
-        hidden_states = self.drop(hidden_states)
+        hidden_states = self.drop(input_embeddings)
 
         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
 
@@ -1111,6 +809,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 attention_mask = attention_mask.to(hidden_states.device)
                 if isinstance(head_mask, torch.Tensor):
                     head_mask = head_mask.to(hidden_states.device)
+
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -1118,6 +817,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
+                    position_ids,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -1129,6 +829,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             else:
                 outputs = block(
                     hidden_states,
+                    position_ids=position_ids,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
@@ -1202,7 +903,9 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
         if self.config.include_motor_time_to_event:
             self.motor_tte = MotorTaskHead(
-                config.n_embd,
+                input_dim=config.n_embd,
+                motor_tte_vocab_size=config.motor_tte_vocab_size,
+                motor_num_time_pieces=config.motor_num_time_pieces,
             )
 
         # Model parallel
@@ -1302,12 +1005,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids,
+        cehrgpt_tokenizer,
         past_key_values=None,
         inputs_embeds=None,
-        lab_token_ids=None,
         **kwargs,
     ):
-
+        ages = kwargs.get("ages")
         # Omit tokens covered by past_key_values
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
@@ -1322,33 +1025,10 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             remove_prefix_length = input_ids.shape[1] - 1
 
             input_ids = input_ids[:, remove_prefix_length:]
+            ages = ages[:, remove_prefix_length:]
 
         attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
         random_vectors = kwargs.get("random_vectors", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-            # Add one more position for the random vectors
-            if (
-                self.cehrgpt.config.causal_sfm
-                and position_ids.shape[-1] >= self.cehrgpt.config.demographics_size
-            ):
-                position_ids = torch.concat(
-                    [
-                        position_ids,
-                        torch.max(position_ids, dim=-1, keepdim=True)[0] + 1,
-                    ],
-                    dim=-1,
-                )
-        else:
-            position_ids = None
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1386,7 +1066,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
-                "
+                "ages": ages,
                 "attention_mask": attention_mask,
                 "random_vectors": random_vectors,
             }
@@ -1396,12 +1076,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
     def motor_nll_loss(
         self,
-
-
-
-
-
-
+        hidden_states,
+        motor_tte_times,
+        motor_tte_event_indicators,
+        motor_tte_task_indicators,
+        motor_tte_masks,
+        motor_end_index,
     ):
         """
         Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
@@ -1409,58 +1089,62 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         for modeling time-to-event data at each visit.
 
         Args:
-
-
-
-
-
-
-
+            hidden_states (Tensor): Hidden representations for sequence tokens [num_of_concepts, hidden_dim].
+            motor_tte_times (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+            motor_tte_task_indicators: (Tensor): Bool indicators (True if included, False if not included).
+            motor_tte_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+            motor_tte_masks (Tensor): Binary indicators whether the prediction should be masked
+                (1 if not masked, 0 if masked).
+            motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
 
         Returns:
             Tensor: Scalar loss value (mean negative log-likelihood).
         """
-
-
+        motor_end_index = motor_end_index.sum().item()
+        motor_tte_times = motor_tte_times.view(
             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-        )[:
-
-            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-        )[:batch_motor_end_index]
-        motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
-            :batch_motor_end_index
-        ]
-        motor_time_indicators = motor_time_indicators.view(
+        )[:motor_end_index].clamp(min=1e-3)
+        motor_tte_event_indicators = motor_tte_event_indicators.reshape(
             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-        )[:
-
+        )[:motor_end_index]
+        # motor_tte_masks = motor_tte_masks.view(
+        #     (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        # )[:motor_end_index]
+
+        tte_features = hidden_states[motor_tte_task_indicators].view(
+            (-1, self.config.n_embd)
+        )
+
+        assert tte_features.shape[0] == motor_tte_times.shape[0], (
             "The number of VE tokens in the labels needs to match up "
             "with the first dimension of motor_time_to_event_vectors. "
-            f"Received ve_token_features.shape[0]: {
-            f"motor_time_to_event_vectors.shape[0]: {
+            f"Received ve_token_features.shape[0]: {tte_features.shape[0]}, "
+            f"motor_time_to_event_vectors.shape[0]: {motor_tte_times.shape[0]}"
         )
-        motor_time_to_event_vectors = motor_time_to_event_vectors[
-            motor_time_to_event_to_include
-        ]
-        motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
-        motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
-        ve_token_features = ve_token_features[motor_time_to_event_to_include]
 
         # Get Exponential parameters from model
-
-        # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
-        dist = Exponential(lambda_p.clamp(min=1e-3))
+        time_dependent_logits = self.motor_tte(tte_features)
 
         # Compute event loss
-
-
-
-
-
-        )
+        # Calculate the accumulative hazard
+        # exp(-sum_{j} lambda_j)
+        survival_loss = torch.exp2(time_dependent_logits + motor_tte_times).mean()
+        event_loss = (
+            -math.log(2)
+            * torch.where(motor_tte_event_indicators, time_dependent_logits, 0).mean()
         )
-
-
+
+        # survival_loss = (
+        #     torch.where(motor_tte_masks, lambda_p * motor_tte_times, 0)
+        #     .sum(dim=1)
+        #     .mean()
+        # )
+        # event_loss = (
+        #     -torch.where(motor_tte_event_indicators, torch.log(lambda_p), 0)
+        #     .sum(dim=1)
+        #     .mean()
+        # )
+        return survival_loss + event_loss
 
     def forward(
         self,
@@ -1469,7 +1153,6 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         random_vectors: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -1478,21 +1161,23 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_to_visits: Optional[torch.FloatTensor] = None,
         time_token_indicators: Optional[torch.BoolTensor] = None,
         sub_time_tokens: Optional[torch.LongTensor] = None,
-
-
-
-
+        motor_tte_times: Optional[torch.FloatTensor] = None,
+        motor_tte_event_indicators: Optional[torch.BoolTensor] = None,
+        motor_tte_task_indicators: Optional[torch.BoolTensor] = None,
+        motor_tte_masks: Optional[torch.BoolTensor] = None,
         motor_end_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        ages: Optional[torch.FloatTensor] = None,
+        epoch_times: Optional[torch.FloatTensor] = None,
     ) -> Union[Tuple, CehrGptCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-
-
-
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -1504,7 +1189,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             values=values,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=
+            position_ids=ages,
             random_vectors=random_vectors,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -1613,23 +1298,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
         if (
             self.config.include_motor_time_to_event
-            and
-            and
-            and
-            and
+            and motor_tte_times is not None
+            and motor_tte_event_indicators is not None
+            and motor_tte_task_indicators is not None
+            and motor_tte_masks is not None
            and motor_end_index is not None
         ):
-            ve_token_id_indices = labels == self.config.ve_token_id
-            ve_token_features = hidden_states[ve_token_id_indices]
-            # Get rid of the last VE features because it's already reached the end of the patient sequence and
-            # there is nothing to predict.
             motor_tte_loss = self.motor_nll_loss(
-
-
-
-
-
-
+                hidden_states=hidden_states,
+                motor_tte_times=motor_tte_times,
+                motor_tte_event_indicators=motor_tte_event_indicators,
+                motor_tte_task_indicators=motor_tte_task_indicators,
+                motor_tte_masks=motor_tte_masks,
+                motor_end_index=motor_end_index,
             )
             loss += motor_tte_loss * self.config.motor_time_to_event_weight
 
@@ -1835,6 +1516,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             else self.generation_config.return_dict_in_generate
         )
 
+        if "cehrgpt_tokenizer" not in model_kwargs:
+            raise RuntimeError(
+                "The cehr-gpt tokenizer must be provided to the "
+                "model.generate(..., cehrgpt_tokenizer=cehrgpt_tokenizer)"
+            )
+
+        # Remove this from the model_kwargs and will pass it to other functions explicitly
+        cehrgpt_tokenizer = model_kwargs.pop("cehrgpt_tokenizer")
+
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
         raw_logits = () if (return_dict_in_generate and output_logits) else None
@@ -1858,18 +1548,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             batch_size, dtype=torch.long, device=input_ids.device
         )
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-        #
-
-            lab_token_ids
-
-
-            )
-        else:
-            lab_token_ids = torch.tensor(
-                [] if self.config.lab_token_ids is None else self.config.lab_token_ids,
-                dtype=torch.int32,
-            )
-
+        # Getting the lab token ids
+        lab_token_ids = torch.tensor(
+            cehrgpt_tokenizer.lab_token_ids,
+            dtype=torch.int32,
+        )
         if model_kwargs.get("value_indicators", None) is not None:
             value_indicators = model_kwargs.get("value_indicators")
         else:
@@ -1895,11 +1578,33 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             model_kwargs["random_vectors"] = None
         model_kwargs["value_indicators"] = value_indicators
         model_kwargs["values"] = values
+
+        # A variable to keep track of time and initialize it to zero
+        batched_time_delta = np.zeros((batch_size,), dtype=np.float32)
+        batched_ages = model_kwargs.get("ages", None)
+        if batched_ages is None:
+            batched_ages = []
+            for token_ids in input_ids.detach().cpu():
+                concept_ids = cehrgpt_tokenizer.decode(
+                    token_ids.numpy(), skip_special_tokens=False
+                )
+                batched_ages.append(construct_age_sequence(concept_ids))
+            # Turn this to a numpy array for easy manipulation
+            batched_ages = np.asarray(batched_ages)
+        else:
+            batched_ages = batched_ages.cpu().numpy()
+        # This is the base to which we will add the time delta
+        base_ages = np.asarray([ages[-1] for ages in batched_ages])
+        # Update the keyword arguments for the prepare_inputs_for_generation
+        model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device
         ):
             # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, cehrgpt_tokenizer, **model_kwargs
+            )
 
             # forward pass to get next token
             outputs = self(
@@ -1944,6 +1649,22 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 probs = nn.functional.softmax(next_token_scores, dim=-1)
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
+            # TODO: decode to get time tokens and recalculate the age at this time step
+            # Look for a potential time token
+            for batch_i, next_concept_id in enumerate(
+                cehrgpt_tokenizer.decode(
+                    next_tokens.detach().cpu().numpy(), skip_special_tokens=False
+                )
+            ):
+                if is_att_token(next_concept_id):
+                    batched_time_delta[batch_i] += extract_time_interval_in_days(
+                        next_concept_id
+                    )
+
+            next_age = (base_ages + batched_time_delta // 365).astype(int)[..., None]
+            batched_ages = np.concatenate([batched_ages, next_age], axis=-1)
+            model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
@@ -1979,6 +1700,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
@@ -2023,7 +1745,7 @@ class FocalLoss(nn.Module):
         self.reduction = reduction
 
     def forward(self, logits, targets):
-        bce_loss =
+        bce_loss = f.binary_cross_entropy_with_logits(logits, targets, reduction="none")
         probs = torch.sigmoid(logits)
         pt = torch.where(targets == 1, probs, 1 - probs)
         focal_term = (1 - pt) ** self.gamma
@@ -2105,12 +1827,13 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        ages: Optional[torch.FloatTensor] = None,
+        epoch_times: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CehrGptSequenceClassifierOutput:
 
@@ -2120,12 +1843,12 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
             values=values,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            position_ids=ages,
         )
 
         if is_sample_pack(attention_mask):
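
Note on the changed generation entry point shown in the hunks above: as of 0.1.4, prepare_inputs_for_generation takes the CEHR-GPT tokenizer as its second positional argument, and the sampling loop raises a RuntimeError unless cehrgpt_tokenizer is supplied through model.generate(...); the tokenizer is used to decode ATT (time-interval) tokens during decoding so the model can maintain the ages sequence that now replaces position_ids. A minimal usage sketch, assuming the tokenizer class is CehrGptTokenizer in cehrgpt.models.tokenization_hf_cehrgpt and that both artifacts load via from_pretrained (the checkpoint path and prompt token ids below are placeholders, not taken from this diff):

import torch

from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer  # assumed class name

# Hypothetical checkpoint path; adjust to a real pretrained CEHR-GPT artifact.
model = CEHRGPT2LMHeadModel.from_pretrained("path/to/cehrgpt-checkpoint").eval()
tokenizer = CehrGptTokenizer.from_pretrained("path/to/cehrgpt-checkpoint")

# Placeholder demographic prompt token ids (batch_size x prompt_length).
prompt_token_ids = torch.tensor([[3, 5, 7, 11]])

with torch.no_grad():
    output = model.generate(
        inputs=prompt_token_ids,
        # Required since 0.1.4: the tokenizer decodes ATT tokens so the model can
        # accumulate elapsed time and extend the per-token ages tensor while sampling.
        cehrgpt_tokenizer=tokenizer,
        max_new_tokens=128,
    )

If ages is not passed explicitly, the sampling loop reconstructs it by decoding the prompt and calling construct_age_sequence, as shown in the @@ -1895,11 +1578,33 @@ hunk.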