cehrgpt 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +398 -36
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +214 -12
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +227 -33
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +117 -2
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +75 -50
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +48 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +85 -57
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +8 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +27 -25
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import Gamma
+from torch.distributions import Exponential, Gamma
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -362,9 +362,37 @@ class GPT2FlashAttention(GPT2Attention):
         )
 
 
-class
+class MotorTaskHead(nn.Module):
+    def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
+        super(MotorTaskHead, self).__init__()
+        self.input_dim = input_dim
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.linear = nn.Sequential(
+            nn.Linear(input_dim, input_dim // 2),
+            gelu_new,
+            nn.Linear(
+                input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
+            ),
+        )
+
+    def forward(self, x):
+        # Ensure scale is positive
+        length = x.shape[0]
+        # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
+        lambda_p = f.softplus(self.linear(x))
+        # Check for NaN values
+        if torch.isnan(lambda_p).any():
+            logger.warning(f"NaN values found in scale_param. x: {x}")
+        # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+        return lambda_p.view(
+            length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+        )
+
+
+class VisitTimeToEventHead(nn.Module):
     def __init__(self, input_dim):
-        super(
+        super(VisitTimeToEventHead, self).__init__()
         self.linear1 = nn.Sequential(
             nn.Linear(input_dim, input_dim // 2), gelu_new, nn.Linear(input_dim // 2, 1)
         )
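The new MotorTaskHead maps each visit-end token representation to one positive rate per (time piece, TTE concept) pair, with softplus guaranteeing valid Exponential rates. A standalone sketch of the shape contract (the sizes 768, 8, and 100 are illustrative, not taken from the package):

import torch
import torch.nn.functional as F

hidden = torch.randn(7, 768)           # 7 [VE] tokens with hidden size 768
pieces, vocab = 8, 100                 # hypothetical motor_num_time_pieces / motor_tte_vocab_size
linear = torch.nn.Linear(768, vocab * pieces)
rates = F.softplus(linear(hidden))     # softplus keeps every rate strictly positive
rates = rates.view(7, pieces, vocab)   # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size)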
@@ -661,32 +689,33 @@ class CEHRGPTPreTrainedModel(PreTrainedModel):
             hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         )
         wpe = self.get_position_embeddings()
-        max_position, embed_dim = wpe.weight.shape
-        if new_num_position_embeddings > max_position:
-            new_embeddings = nn.Embedding(
-                new_num_position_embeddings,
-                embed_dim,
-                device=wpe.weight.device,
-                dtype=wpe.weight.dtype,
-            )
-
-            # initialize all new embeddings (in particular added tokens)
-            self._init_weights(new_embeddings)
-            if is_deepspeed_zero3_enabled() and not is_quantized:
-                import deepspeed
+        if wpe is not None:
+            max_position, embed_dim = wpe.weight.shape
+            if new_num_position_embeddings > max_position:
+                new_embeddings = nn.Embedding(
+                    new_num_position_embeddings,
+                    embed_dim,
+                    device=wpe.weight.device,
+                    dtype=wpe.weight.dtype,
+                )
 
-                params = [wpe.weight, new_embeddings.weight]
-                with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                # initialize all new embeddings (in particular added tokens)
+                self._init_weights(new_embeddings)
+                if is_deepspeed_zero3_enabled() and not is_quantized:
+                    import deepspeed
+
+                    params = [wpe.weight, new_embeddings.weight]
+                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                        new_embeddings.weight.data[:max_position, :] = (
+                            wpe.weight.data[:max_position, :]
+                        )
+                else:
                     new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
                         :max_position, :
                     ]
-            else:
-                new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
-                    :max_position, :
-                ]
-        self.set_position_embeddings(new_embeddings)
-        self.config.max_position_embeddings = new_num_position_embeddings
-        self.update_attn_bias(new_num_position_embeddings)
+            self.set_position_embeddings(new_embeddings)
+            self.config.max_position_embeddings = new_num_position_embeddings
+            self.update_attn_bias(new_num_position_embeddings)
 
 
 class CEHRGPT2Model(CEHRGPTPreTrainedModel):
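The rewritten resize logic now skips models that have no positional embeddings (wpe is None) and only grows the table, copying trained rows into a larger, freshly initialized embedding. The copy-then-swap pattern, as a minimal sketch outside the DeepSpeed branch:

import torch
import torch.nn as nn

old = nn.Embedding(512, 64)
new = nn.Embedding(1024, 64, device=old.weight.device, dtype=old.weight.dtype)
with torch.no_grad():
    new.weight[:512, :] = old.weight   # trained positions survive; rows 512+ keep their fresh init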
@@ -740,6 +769,10 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             )
             self.update_attn_bias(self.config.sample_packing_max_positions)
 
+    def enable_position_embeddings(self):
+        self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
+        self.config.exclude_position_ids = False
+
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
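A plausible call site for the new enable_position_embeddings: re-enabling learned positions on a checkpoint that was pretrained with exclude_position_ids (the checkpoint path here is hypothetical):

from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel

model = CEHRGPT2LMHeadModel.from_pretrained("path/to/checkpoint")  # hypothetical path
model.cehrgpt.enable_position_embeddings()  # allocates wpe and clears exclude_position_ids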
@@ -1043,7 +1076,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         )
 
         if not self.exclude_position_ids:
-            position_embeds = self.wpe(position_ids)
+            position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
             hidden_states = input_embeddings + position_embeds
         else:
             hidden_states = input_embeddings
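The added .to(input_embeddings.dtype) cast matters under mixed precision: adding a float32 embedding output to half-precision input embeddings silently promotes the hidden states. A tiny repro of the promotion rule:

import torch

a = torch.randn(2, 4, dtype=torch.bfloat16)   # stand-in for input_embeddings
b = torch.randn(2, 4)                         # float32, like an un-cast nn.Embedding output
print((a + b).dtype)                          # torch.float32 -- silent upcast
print((a + b.to(a.dtype)).dtype)              # torch.bfloat16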
@@ -1152,7 +1185,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         super().__init__(config)
         self.cehrgpt = CEHRGPT2Model(config)
         if self.config.include_ttv_prediction:
-            self.tte_head =
+            self.tte_head = VisitTimeToEventHead(config.n_embd)
 
         if self.config.use_sub_time_tokenization:
             self.time_token_lm_head = nn.Linear(
@@ -1165,6 +1198,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             config.n_embd, config.value_vocab_size, bias=False
         )
 
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = MotorTaskHead(
+                config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+            )
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -1192,6 +1230,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         self.value_head = self.value_head.to(self.cehrgpt.first_device)
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to(self.cehrgpt.first_device)
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to(self.cehrgpt.first_device)
         self.model_parallel = True
 
     def deparallelize(self):
@@ -1206,6 +1246,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         self.value_head = self.value_head.to("cpu")
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to("cpu")
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()
 
@@ -1233,6 +1275,28 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
     def update_attn_bias(self, max_position_embeddings: int):
         self.cehrgpt.update_attn_bias(max_position_embeddings)
 
+    def update_motor_tte_vocab_size(
+        self, motor_tte_vocab_size: Optional[int] = None
+    ) -> None:
+        update_motor_tte_layer = False
+        if motor_tte_vocab_size and motor_tte_vocab_size > 0:
+            if self.config.include_motor_time_to_event:
+                if self.config.motor_tte_vocab_size != motor_tte_vocab_size:
+                    self.config.include_motor_time_to_event = True
+                    self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                    update_motor_tte_layer = True
+            else:
+                self.config.include_motor_time_to_event = True
+                self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                update_motor_tte_layer = True
+
+        if update_motor_tte_layer:
+            self.motor_tte = MotorTaskHead(
+                self.config.n_embd,
+                self.config.motor_tte_vocab_size,
+                self.config.motor_num_time_pieces,
+            )
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
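update_motor_tte_vocab_size lets a fine-tuning run attach a task-specific MOTOR vocabulary; the head is rebuilt (with fresh weights) only when the size actually changes. A hedged usage sketch, since no call site appears in this diff:

model.update_motor_tte_vocab_size(motor_tte_vocab_size=2048)  # rebuilds self.motor_tte if 2048 differs from the config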
@@ -1328,6 +1392,74 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
         return model_inputs
 
+    def motor_nll_loss(
+        self,
+        ve_token_features,
+        motor_time_to_event_vectors,
+        motor_event_indicators,
+        motor_time_to_event_to_include,
+        motor_time_indicators,
+        batch_motor_end_index,
+    ):
+        """
+        Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
+
+        for modeling time-to-event data at each visit.
+
+        Args:
+            ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
+            motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+            motor_time_to_event_to_include (Tensor): Bool indicators (True if included, False if not included).
+            motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+            motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
+                time bucket (1 if censored, 0 if event occurred).
+            batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+
+        Returns:
+            Tensor: Scalar loss value (mean negative log-likelihood).
+        """
+        batch_motor_end_index = batch_motor_end_index.sum().item()
+        motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index].clamp(min=1e-3)
+        motor_event_indicators = motor_event_indicators.reshape(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
+            :batch_motor_end_index
+        ]
+        motor_time_indicators = motor_time_indicators.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+            "The number of VE tokens in the labels needs to match up "
+            "with the first dimension of motor_time_to_event_vectors. "
+            f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
+            f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+        )
+        motor_time_to_event_vectors = motor_time_to_event_vectors[
+            motor_time_to_event_to_include
+        ]
+        motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
+        motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
+        ve_token_features = ve_token_features[motor_time_to_event_to_include]
+
+        # Get Exponential parameters from model
+        lambda_p = self.motor_tte(ve_token_features)
+        # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
+        dist = Exponential(lambda_p.clamp(min=1e-3))
+
+        # Compute event loss
+        tte_loss = torch.where(
+            motor_event_indicators,
+            -dist.log_prob(motor_time_to_event_vectors),
+            -torch.log(
+                1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
+            ),
+        )
+        tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
+        return torch.mean(tte_loss)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
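Despite the docstring's mention of LogNormal, the loss computed here is a censored Exponential likelihood: the log-density applies where an event was observed and the log-survival term -log(1 - CDF) where it was censored. A self-contained sketch of that split (values are illustrative, not from the package):

import torch
from torch.distributions import Exponential

rate = torch.tensor([0.5, 0.5])          # predicted hazards (lambda_p)
t = torch.tensor([2.0, 2.0])             # event time / censoring time
observed = torch.tensor([True, False])

dist = Exponential(rate)
nll = torch.where(
    observed,
    -dist.log_prob(t),                   # rate*t - log(rate) for observed events
    -torch.log(1 - dist.cdf(t) + 1e-6),  # ~rate*t for censored events (log-survival)
)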
@@ -1344,6 +1476,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_to_visits: Optional[torch.FloatTensor] = None,
         time_token_indicators: Optional[torch.BoolTensor] = None,
         sub_time_tokens: Optional[torch.LongTensor] = None,
+        motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
+        motor_event_indicators: Optional[torch.BoolTensor] = None,
+        motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
+        motor_time_indicators: Optional[torch.BoolTensor] = None,
+        motor_end_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1403,6 +1540,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_token_loss = None
         time_to_visit_loss = None
         token_value_loss = None
+        motor_tte_loss = None
+
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
@@ -1470,9 +1609,35 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 entropy_penalty = entropy.sum() / total_num_tokens
                 loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha
 
+            if (
+                self.config.include_motor_time_to_event
+                and motor_time_to_event_vectors is not None
+                and motor_event_indicators is not None
+                and motor_time_to_event_to_include is not None
+                and motor_time_indicators is not None
+                and motor_end_index is not None
+            ):
+                ve_token_id_indices = labels == self.config.ve_token_id
+                ve_token_features = hidden_states[ve_token_id_indices]
+                # Get rid of the last VE features because it's already reached the end of the patient sequence and
+                # there is nothing to predict.
+                motor_tte_loss = self.motor_nll_loss(
+                    ve_token_features=ve_token_features,
+                    motor_time_to_event_vectors=motor_time_to_event_vectors,
+                    motor_event_indicators=motor_event_indicators,
+                    motor_time_to_event_to_include=motor_time_to_event_to_include,
+                    motor_time_indicators=motor_time_indicators,
+                    batch_motor_end_index=motor_end_index,
+                )
+                loss += motor_tte_loss * self.config.motor_time_to_event_weight
+
             # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
             # predictions for year/month/token
-            if self.config.use_sub_time_tokenization:
+            if (
+                self.config.use_sub_time_tokenization
+                and sub_time_tokens is not None
+                and time_token_indicators is not None
+            ):
                 # Split the last dimensions into three parts
                 time_loss_fct = CrossEntropyLoss(reduction="none")
                 time_token_logits = self.time_token_lm_head(
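The boolean mask labels == self.config.ve_token_id flattens (batch, seq, hidden) into (num_ve_tokens, hidden) in batch order, which is what lets motor_nll_loss treat all visits in the batch as a single axis. A small demonstration, with 7 standing in for ve_token_id:

import torch

hidden_states = torch.randn(2, 5, 4)          # (batch, seq, hidden)
labels = torch.tensor([[7, 1, 7, 0, 0],
                       [1, 7, 0, 0, 0]])
ve_token_features = hidden_states[labels == 7]
print(ve_token_features.shape)                # torch.Size([3, 4])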
@@ -1501,7 +1666,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 time_token_loss = time_token_loss.sum() / total_num_tokens
                 loss += time_token_loss * self.config.time_token_loss_weight
 
-            if time_to_visits is not None:
+            if time_to_visits is not None and time_to_visits is not None:
                 # Get lambda and k parameters
                 lambda_param, k_param = self.tte_head(hidden_states)
 
@@ -1512,14 +1677,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
 
                 # Move to the same device as lambda_param
                 shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
-
                 time_to_visit_indicator = shift_time_to_visits >= 0
                 # Define the Gamma distribution
                 dist = Gamma(
                     shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
                 )
                 # Compute log-probs and apply the time_to_visit_indicator
-                log_probs = dist.log_prob(
+                log_probs = dist.log_prob(
+                    torch.clamp(shift_time_to_visits, min=1e-3) + 1e-6
+                )
                 log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
                 time_to_visit_loss = -log_probs.sum() / total_num_tokens
                 # Compute the loss
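The new clamp keeps intervals strictly inside the Gamma support before log_prob is evaluated, since a zero time-to-visit sits on the boundary of the distribution's support. A minimal illustration of the same guard:

import torch
from torch.distributions import Gamma

dist = Gamma(torch.tensor(2.0), torch.tensor(1.0))
t = torch.tensor([0.0, 2.5])              # a zero interval falls outside the Gamma support
t_safe = torch.clamp(t, min=1e-3) + 1e-6  # same guard as the diff above
print(dist.log_prob(t_safe))              # finite for both entries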
@@ -1564,6 +1730,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             time_token_loss=time_token_loss,
             time_to_visit_loss=time_to_visit_loss,
             token_value_loss=token_value_loss,
+            motor_tte_loss=motor_tte_loss,
         )
 
     @staticmethod
@@ -1837,6 +2004,27 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         )
 
 
+class FocalLoss(nn.Module):
+    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
+        super().__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, logits, targets):
+        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+        probs = torch.sigmoid(logits)
+        pt = torch.where(targets == 1, probs, 1 - probs)
+        focal_term = (1 - pt) ** self.gamma
+        loss = self.alpha * focal_term * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
 class CehrGptForClassification(CEHRGPTPreTrainedModel):
     _keep_in_fp32_modules = ["age_batch_norm", "dense_layer", "classifier"]
 
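With gamma=2.0, an easy example at pt = 0.9 keeps only (1 - 0.9)^2 = 1% of its BCE contribution, so the training signal shifts toward hard or rare positives. A usage sketch, assuming the FocalLoss class added above is in scope (inputs are made up):

import torch

criterion = FocalLoss(alpha=0.25, gamma=2.0)
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()
print(criterion(logits, targets))  # scalar, since reduction="mean"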
|
@@ -1859,7 +2047,6 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         self.model_parallel = False
         self.device_map = None
         self.gradient_checkpointing = False
-
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1971,7 +2158,14 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
 
         loss = None
         if classifier_label is not None:
-            loss_fct = nn.BCEWithLogitsLoss()
+            if self.config.class_weights:
+                class_weights = torch.tensor(
+                    [self.config.class_weights[1] / self.config.class_weights[0]],
+                    dtype=torch.float32,
+                ).to(logits.device)
+            else:
+                class_weights = None
+            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
             loss = loss_fct(logits, classifier_label)
 
         return CehrGptSequenceClassifierOutput(
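The ratio class_weights[1] / class_weights[0] becomes BCEWithLogitsLoss's pos_weight, which scales only the positive-class term of the loss. Worked numbers, assuming class_weights stores [weight_for_class_0, weight_for_class_1]:

import torch
import torch.nn as nn

class_weights = [0.1, 0.9]                                        # hypothetical config values
pos_weight = torch.tensor([class_weights[1] / class_weights[0]])  # 9.0: each positive counts 9x
loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fct(torch.randn(4, 1), torch.randint(0, 2, (4, 1)).float())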
cehrgpt/models/hf_modeling_outputs.py
CHANGED
@@ -85,6 +85,7 @@ class CehrGptCausalLMOutput(ModelOutput):
     time_token_loss: Optional[torch.FloatTensor] = None
     time_to_visit_loss: Optional[torch.FloatTensor] = None
     token_value_loss: Optional[torch.FloatTensor] = None
+    motor_tte_loss: Optional[torch.FloatTensor] = None
 
 
 @dataclass