cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import
+from torch.distributions import Exponential, Gamma
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -102,7 +102,9 @@ def is_sample_pack(attention_mask: torch.Tensor) -> bool:
     attention_mask = attention_mask.flip(dims=[1])

     nonzero_counts = attention_mask.sum(dim=1)
-    max_token_positions = torch.argmax(
+    max_token_positions = torch.argmax(
+        attention_mask.to(torch.int32).flip(dims=[1]), dim=1
+    )
     max_indices = attention_mask.shape[1] - 1 - max_token_positions
     return torch.any(nonzero_counts < (max_indices + 1)).item()

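The change above casts the mask to int32 before torch.argmax, presumably because argmax rejects bool tensors in some PyTorch builds. The sketch below is not from the package (the toy mask and names are invented); it only illustrates the idea behind the check, namely that a row counts as sample-packed when padding appears before its last real token:

    import torch

    # Row 0 is plain right-padding; row 1 has a gap between two packed samples.
    attention_mask = torch.tensor(
        [[1, 1, 1, 0, 0],
         [1, 1, 0, 1, 0]]
    )

    nonzero_counts = attention_mask.sum(dim=1)
    # Flip, cast to int32 (argmax does not accept bool), and locate the last non-zero entry.
    flipped = attention_mask.to(torch.int32).flip(dims=[1])
    last_token_positions = attention_mask.shape[1] - 1 - torch.argmax(flipped, dim=1)

    # Packing is detected when a row has fewer tokens than its last token position implies.
    print(torch.any(nonzero_counts < (last_token_positions + 1)).item())  # True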
@@ -362,9 +364,37 @@ class GPT2FlashAttention(GPT2Attention):
         )


-class
+class MotorTaskHead(nn.Module):
+    def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
+        super(MotorTaskHead, self).__init__()
+        self.input_dim = input_dim
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.linear = nn.Sequential(
+            nn.Linear(input_dim, input_dim // 2),
+            gelu_new,
+            nn.Linear(
+                input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
+            ),
+        )
+
+    def forward(self, x):
+        # Ensure scale is positive
+        length = x.shape[0]
+        # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
+        lambda_p = f.softplus(self.linear(x))
+        # Check for NaN values
+        if torch.isnan(lambda_p).any():
+            logger.warning(f"NaN values found in scale_param. x: {x}")
+        # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+        return lambda_p.view(
+            length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+        )
+
+
+class VisitTimeToEventHead(nn.Module):
     def __init__(self, input_dim):
-        super(
+        super(VisitTimeToEventHead, self).__init__()
         self.linear1 = nn.Sequential(
             nn.Linear(input_dim, input_dim // 2), gelu_new, nn.Linear(input_dim // 2, 1)
         )
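MotorTaskHead maps each visit-end ([VE]) hidden state through a small MLP and a softplus to a non-negative rate for every (time piece, TTE concept) pair, i.e. the rate parameters of a piecewise-exponential time-to-event model. A shape-only sketch with made-up sizes (the real dimensions come from the model config, and the package uses gelu_new rather than nn.GELU):

    import torch
    from torch import nn

    num_visits, hidden_dim = 7, 768           # hypothetical
    num_time_pieces, tte_vocab_size = 8, 512  # hypothetical

    ve_token_features = torch.randn(num_visits, hidden_dim)

    # Stand-in for the head's nn.Sequential: hidden_dim -> hidden_dim // 2 -> pieces * vocab.
    mlp = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim // 2),
        nn.GELU(),
        nn.Linear(hidden_dim // 2, num_time_pieces * tte_vocab_size),
    )
    rates = nn.functional.softplus(mlp(ve_token_features)).view(
        num_visits, num_time_pieces, tte_vocab_size
    )

    assert rates.shape == (num_visits, num_time_pieces, tte_vocab_size)
    assert (rates >= 0).all()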
@@ -661,32 +691,33 @@ class CEHRGPTPreTrainedModel(PreTrainedModel):
             hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         )
         wpe = self.get_position_embeddings()
-
-
-
-
-
-
-
-
-
-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
-        if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
+        if wpe is not None:
+            max_position, embed_dim = wpe.weight.shape
+            if new_num_position_embeddings > max_position:
+                new_embeddings = nn.Embedding(
+                    new_num_position_embeddings,
+                    embed_dim,
+                    device=wpe.weight.device,
+                    dtype=wpe.weight.dtype,
+                )

-
-
+                # initialize all new embeddings (in particular added tokens)
+                self._init_weights(new_embeddings)
+                if is_deepspeed_zero3_enabled() and not is_quantized:
+                    import deepspeed
+
+                    params = [wpe.weight, new_embeddings.weight]
+                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                        new_embeddings.weight.data[:max_position, :] = (
+                            wpe.weight.data[:max_position, :]
+                        )
+                else:
                     new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
                         :max_position, :
                     ]
-
-
-
-            ]
-        self.set_position_embeddings(new_embeddings)
-        self.config.max_position_embeddings = new_num_position_embeddings
-        self.update_attn_bias(new_num_position_embeddings)
+            self.set_position_embeddings(new_embeddings)
+            self.config.max_position_embeddings = new_num_position_embeddings
+            self.update_attn_bias(new_num_position_embeddings)


 class CEHRGPT2Model(CEHRGPTPreTrainedModel):
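The rewritten resize path only acts when a position-embedding table exists and the requested size is larger; it copies the learned rows into a fresh nn.Embedding (gathering parameters first under DeepSpeed ZeRO-3) and then updates the config and attention bias. A minimal sketch of the non-DeepSpeed branch, written outside the class with hypothetical sizes:

    import torch
    from torch import nn

    wpe = nn.Embedding(512, 64)              # existing table (hypothetical sizes)
    new_num_position_embeddings = 1024

    if new_num_position_embeddings > wpe.num_embeddings:
        new_wpe = nn.Embedding(
            new_num_position_embeddings,
            wpe.embedding_dim,
            device=wpe.weight.device,
            dtype=wpe.weight.dtype,
        )
        with torch.no_grad():
            # keep the already-trained rows; the extra rows keep their fresh initialization
            new_wpe.weight.data[: wpe.num_embeddings, :] = wpe.weight.data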
@@ -740,6 +771,10 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             )
             self.update_attn_bias(self.config.sample_packing_max_positions)

+    def enable_position_embeddings(self):
+        self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
+        self.config.exclude_position_ids = False
+
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -1043,7 +1078,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             )

         if not self.exclude_position_ids:
-            position_embeds = self.wpe(position_ids)
+            position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
             hidden_states = input_embeddings + position_embeds
         else:
             hidden_states = input_embeddings
@@ -1152,7 +1187,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         super().__init__(config)
         self.cehrgpt = CEHRGPT2Model(config)
         if self.config.include_ttv_prediction:
-            self.tte_head =
+            self.tte_head = VisitTimeToEventHead(config.n_embd)

         if self.config.use_sub_time_tokenization:
             self.time_token_lm_head = nn.Linear(
@@ -1165,6 +1200,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 config.n_embd, config.value_vocab_size, bias=False
             )

+        if self.config.include_motor_time_to_event:
+            self.motor_tte = MotorTaskHead(
+                config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+            )
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -1192,6 +1232,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             self.value_head = self.value_head.to(self.cehrgpt.first_device)
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to(self.cehrgpt.first_device)
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to(self.cehrgpt.first_device)
         self.model_parallel = True

     def deparallelize(self):
@@ -1206,6 +1248,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         self.value_head = self.value_head.to("cpu")
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to("cpu")
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()

@@ -1233,6 +1277,28 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
     def update_attn_bias(self, max_position_embeddings: int):
         self.cehrgpt.update_attn_bias(max_position_embeddings)

+    def update_motor_tte_vocab_size(
+        self, motor_tte_vocab_size: Optional[int] = None
+    ) -> None:
+        update_motor_tte_layer = False
+        if motor_tte_vocab_size and motor_tte_vocab_size > 0:
+            if self.config.include_motor_time_to_event:
+                if self.config.motor_tte_vocab_size != motor_tte_vocab_size:
+                    self.config.include_motor_time_to_event = True
+                    self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                    update_motor_tte_layer = True
+            else:
+                self.config.include_motor_time_to_event = True
+                self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                update_motor_tte_layer = True
+
+        if update_motor_tte_layer:
+            self.motor_tte = MotorTaskHead(
+                self.config.n_embd,
+                self.config.motor_tte_vocab_size,
+                self.config.motor_num_time_pieces,
+            )
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
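update_motor_tte_vocab_size is presumably intended for fine-tuning a checkpoint whose config either lacks the MOTOR time-to-event head or was built with a different TTE vocabulary: it updates the config flags and, only when something actually changed, rebuilds motor_tte. A hypothetical call site (the variable names are illustrative, not taken from the runners):

    # model: a CEHRGPT2LMHeadModel restored via from_pretrained(...)
    # motor_tte_concepts: hypothetical list of concepts to predict time-to-event for
    model.update_motor_tte_vocab_size(motor_tte_vocab_size=len(motor_tte_concepts))
    assert model.config.include_motor_time_to_event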
@@ -1328,6 +1394,74 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

         return model_inputs

+    def motor_nll_loss(
+        self,
+        ve_token_features,
+        motor_time_to_event_vectors,
+        motor_event_indicators,
+        motor_time_to_event_to_include,
+        motor_time_indicators,
+        batch_motor_end_index,
+    ):
+        """
+        Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
+
+        for modeling time-to-event data at each visit.
+
+        Args:
+            ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
+            motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+            motor_time_to_event_to_include: (Tensor): Bool indicators (True if included, False if not included).
+            motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+            motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
+                time bucket (1 if censored, 0 if event occurred).
+            batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+
+        Returns:
+            Tensor: Scalar loss value (mean negative log-likelihood).
+        """
+        batch_motor_end_index = batch_motor_end_index.sum().item()
+        motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index].clamp(min=1e-3)
+        motor_event_indicators = motor_event_indicators.reshape(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
+            :batch_motor_end_index
+        ]
+        motor_time_indicators = motor_time_indicators.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+            "The number of VE tokens in the labels needs to match up "
+            "with the first dimension of motor_time_to_event_vectors. "
+            f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
+            f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+        )
+        motor_time_to_event_vectors = motor_time_to_event_vectors[
+            motor_time_to_event_to_include
+        ]
+        motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
+        motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
+        ve_token_features = ve_token_features[motor_time_to_event_to_include]
+
+        # Get Exponential parameters from model
+        lambda_p = self.motor_tte(ve_token_features)
+        # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
+        dist = Exponential(lambda_p.clamp(min=1e-3))
+
+        # Compute event loss
+        tte_loss = torch.where(
+            motor_event_indicators,
+            -dist.log_prob(motor_time_to_event_vectors),
+            -torch.log(
+                1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
+            ),
+        )
+        tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
+        return torch.mean(tte_loss)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
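Although the docstring mentions a LogNormal distribution, the body feeds the MotorTaskHead rates into torch.distributions.Exponential: one branch of the torch.where contributes the negative log-density, the other contributes the negative log of the survival function 1 - CDF(t), and entries outside the relevant time pieces are zeroed out. A toy, standalone illustration of that censored-NLL split (shapes and values are invented; the real tensors come from the data collator):

    import torch
    from torch.distributions import Exponential

    rates = torch.tensor([[0.5, 2.0]])               # (num_visits, tte_vocab): head output
    times = torch.tensor([[1.0, 0.3]])               # durations, clamped to be positive
    event_indicator = torch.tensor([[True, False]])  # which branch each entry uses

    dist = Exponential(rates.clamp(min=1e-3))
    nll = torch.where(
        event_indicator,
        -dist.log_prob(times),                                       # log-density term
        -torch.log(1 - dist.cdf(times).clamp(max=1 - 1e-6) + 1e-6),  # log-survival term
    )
    print(nll.mean())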
@@ -1344,6 +1478,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_to_visits: Optional[torch.FloatTensor] = None,
         time_token_indicators: Optional[torch.BoolTensor] = None,
         sub_time_tokens: Optional[torch.LongTensor] = None,
+        motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
+        motor_event_indicators: Optional[torch.BoolTensor] = None,
+        motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
+        motor_time_indicators: Optional[torch.BoolTensor] = None,
+        motor_end_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1403,6 +1542,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_token_loss = None
         time_to_visit_loss = None
         token_value_loss = None
+        motor_tte_loss = None
+
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
@@ -1470,9 +1611,35 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 entropy_penalty = entropy.sum() / total_num_tokens
                 loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha

+            if (
+                self.config.include_motor_time_to_event
+                and motor_time_to_event_vectors is not None
+                and motor_event_indicators is not None
+                and motor_time_to_event_to_include is not None
+                and motor_time_indicators is not None
+                and motor_end_index is not None
+            ):
+                ve_token_id_indices = labels == self.config.ve_token_id
+                ve_token_features = hidden_states[ve_token_id_indices]
+                # Get rid of the last VE features because it's already reached the end of the patient sequence and
+                # there is nothing to predict.
+                motor_tte_loss = self.motor_nll_loss(
+                    ve_token_features=ve_token_features,
+                    motor_time_to_event_vectors=motor_time_to_event_vectors,
+                    motor_event_indicators=motor_event_indicators,
+                    motor_time_to_event_to_include=motor_time_to_event_to_include,
+                    motor_time_indicators=motor_time_indicators,
+                    batch_motor_end_index=motor_end_index,
+                )
+                loss += motor_tte_loss * self.config.motor_time_to_event_weight
+
             # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
             # predictions for year/month/token
-            if
+            if (
+                self.config.use_sub_time_tokenization
+                and sub_time_tokens is not None
+                and time_token_indicators is not None
+            ):
                 # Split the last dimensions into three parts
                 time_loss_fct = CrossEntropyLoss(reduction="none")
                 time_token_logits = self.time_token_lm_head(
@@ -1501,7 +1668,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 time_token_loss = time_token_loss.sum() / total_num_tokens
                 loss += time_token_loss * self.config.time_token_loss_weight

-            if time_to_visits is not None:
+            if time_to_visits is not None and time_to_visits is not None:
                 # Get lambda and k parameters
                 lambda_param, k_param = self.tte_head(hidden_states)

@@ -1512,14 +1679,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

                 # Move to the same device as lambda_param
                 shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
-
                 time_to_visit_indicator = shift_time_to_visits >= 0
                 # Define the Gamma distribution
                 dist = Gamma(
                     shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
                 )
                 # Compute log-probs and apply the time_to_visit_indicator
-                log_probs = dist.log_prob(
+                log_probs = dist.log_prob(
+                    torch.clamp(shift_time_to_visits, min=1e-3) + 1e-6
+                )
                 log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
                 time_to_visit_loss = -log_probs.sum() / total_num_tokens
                 # Compute the loss
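The added clamp around shift_time_to_visits presumably protects the Gamma log-density from zero or negative durations, which yield -inf log-probabilities and, in turn, NaN gradients. A quick standalone check of that boundary behaviour:

    import torch
    from torch.distributions import Gamma

    # validate_args=False so the boundary value is evaluated instead of rejected
    dist = Gamma(torch.tensor(2.0), torch.tensor(1.0), validate_args=False)
    zero_duration = torch.tensor(0.0)

    print(dist.log_prob(zero_duration))                                 # -inf
    print(dist.log_prob(torch.clamp(zero_duration, min=1e-3) + 1e-6))   # finite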
@@ -1564,6 +1732,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             time_token_loss=time_token_loss,
             time_to_visit_loss=time_to_visit_loss,
             token_value_loss=token_value_loss,
+            motor_tte_loss=motor_tte_loss,
         )

     @staticmethod
@@ -1681,6 +1850,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape
+        model_kwargs["attention_mask"] = input_ids != pad_token_id
         if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
         this_peer_finished = False
@@ -1699,11 +1869,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             [] if self.config.lab_token_ids is None else self.config.lab_token_ids,
             dtype=torch.int32,
         )
-
-
-
-
-
+
+        if model_kwargs.get("value_indicators", None) is not None:
+            value_indicators = model_kwargs.get("value_indicators")
+        else:
+            value_indicators = torch.zeros_like(input_ids).to(torch.bool)
+
+        if model_kwargs.get("values", None) is not None:
+            values = model_kwargs.get("values")
+        else:
+            values = torch.zeros_like(
+                input_ids,
+                dtype=torch.int32,
+            )
         # Generate initial random_vectors
         if self.cehrgpt.config.causal_sfm:
             model_kwargs["random_vectors"] = torch.rand(
@@ -1837,6 +2015,27 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         )


+class FocalLoss(nn.Module):
+    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
+        super().__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, logits, targets):
+        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+        probs = torch.sigmoid(logits)
+        pt = torch.where(targets == 1, probs, 1 - probs)
+        focal_term = (1 - pt) ** self.gamma
+        loss = self.alpha * focal_term * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
 class CehrGptForClassification(CEHRGPTPreTrainedModel):
     _keep_in_fp32_modules = ["age_batch_norm", "dense_layer", "classifier"]

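FocalLoss here is the standard binary focal loss on logits (alpha-weighted BCE scaled by (1 - pt) ** gamma), added at module level in hf_cehrgpt.py. A small usage sketch, assuming the class is importable from cehrgpt.models.hf_cehrgpt as shown above (inputs are made up):

    import torch
    from cehrgpt.models.hf_cehrgpt import FocalLoss  # assumes the class ships as shown above

    focal = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")

    logits = torch.tensor([2.0, -1.0, 0.5])   # raw scores
    targets = torch.tensor([1.0, 0.0, 1.0])   # binary labels as floats

    loss = focal(logits, targets)
    # Confidently classified examples (pt close to 1) are down-weighted by (1 - pt) ** gamma,
    # so the least confident example contributes most to the mean.
    print(loss)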
@@ -1859,7 +2058,6 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         self.model_parallel = False
         self.device_map = None
         self.gradient_checkpointing = False
-
         # Initialize weights and apply final processing
         self.post_init()

@@ -1971,7 +2169,14 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):

         loss = None
         if classifier_label is not None:
-
+            if self.config.class_weights:
+                class_weights = torch.tensor(
+                    [self.config.class_weights[1] / self.config.class_weights[0]],
+                    dtype=torch.float32,
+                ).to(logits.device)
+            else:
+                class_weights = None
+            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
             loss = loss_fct(logits, classifier_label)

         return CehrGptSequenceClassifierOutput(
cehrgpt/models/hf_modeling_outputs.py
CHANGED
@@ -85,6 +85,7 @@ class CehrGptCausalLMOutput(ModelOutput):
     time_token_loss: Optional[torch.FloatTensor] = None
     time_to_visit_loss: Optional[torch.FloatTensor] = None
     token_value_loss: Optional[torch.FloatTensor] = None
+    motor_tte_loss: Optional[torch.FloatTensor] = None


 @dataclass