cehrgpt 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry to which they were published. It is provided for informational purposes only.
@@ -6,7 +6,7 @@ import numpy as np
  import torch
  import torch.nn.functional as f
  from torch import nn
- from torch.distributions import Gamma, Weibull
+ from torch.distributions import Exponential, Gamma
  from torch.nn import CrossEntropyLoss
  from torch.nn import functional as F
  from transformers import PreTrainedModel
@@ -362,9 +362,37 @@ class GPT2FlashAttention(GPT2Attention):
          )


- class WeibullModel(nn.Module):
+ class MotorTaskHead(nn.Module):
+     def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
+         super(MotorTaskHead, self).__init__()
+         self.input_dim = input_dim
+         self.motor_tte_vocab_size = motor_tte_vocab_size
+         self.motor_num_time_pieces = motor_num_time_pieces
+         self.linear = nn.Sequential(
+             nn.Linear(input_dim, input_dim // 2),
+             gelu_new,
+             nn.Linear(
+                 input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
+             ),
+         )
+
+     def forward(self, x):
+         # Ensure scale is positive
+         length = x.shape[0]
+         # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
+         lambda_p = f.softplus(self.linear(x))
+         # Check for NaN values
+         if torch.isnan(lambda_p).any():
+             logger.warning(f"NaN values found in scale_param. x: {x}")
+         # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+         return lambda_p.view(
+             length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+         )
+
+
+ class VisitTimeToEventHead(nn.Module):
      def __init__(self, input_dim):
-         super(WeibullModel, self).__init__()
+         super(VisitTimeToEventHead, self).__init__()
          self.linear1 = nn.Sequential(
              nn.Linear(input_dim, input_dim // 2), gelu_new, nn.Linear(input_dim // 2, 1)
          )
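
The new MotorTaskHead turns each hidden state into one non-negative rate per (time piece, TTE concept) pair, i.e. a piecewise-constant exponential hazard over the MOTOR task vocabulary. A minimal, self-contained sketch of the same shape contract (the dimensions and the nn.GELU stand-in are illustrative, not the package defaults):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden_dim, tte_vocab, num_pieces = 768, 50, 8   # hypothetical sizes
    num_ve_tokens = 4                                # hypothetical [VE] tokens in a batch

    # Same pattern as MotorTaskHead: two-layer MLP, softplus keeps every rate positive.
    head = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim // 2),
        nn.GELU(),  # stand-in for transformers' gelu_new activation
        nn.Linear(hidden_dim // 2, tte_vocab * num_pieces),
    )

    x = torch.randn(num_ve_tokens, hidden_dim)
    rates = F.softplus(head(x)).view(num_ve_tokens, num_pieces, tte_vocab)
    print(rates.shape)  # torch.Size([4, 8, 50]), all entries > 0
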
@@ -661,32 +689,33 @@ class CEHRGPTPreTrainedModel(PreTrainedModel):
              hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
          )
          wpe = self.get_position_embeddings()
-         max_position, embed_dim = wpe.weight.shape
-         if new_num_position_embeddings > max_position:
-             new_embeddings = nn.Embedding(
-                 new_num_position_embeddings,
-                 embed_dim,
-                 device=wpe.weight.device,
-                 dtype=wpe.weight.dtype,
-             )
-
-             # initialize all new embeddings (in particular added tokens)
-             self._init_weights(new_embeddings)
-             if is_deepspeed_zero3_enabled() and not is_quantized:
-                 import deepspeed
+         if wpe is not None:
+             max_position, embed_dim = wpe.weight.shape
+             if new_num_position_embeddings > max_position:
+                 new_embeddings = nn.Embedding(
+                     new_num_position_embeddings,
+                     embed_dim,
+                     device=wpe.weight.device,
+                     dtype=wpe.weight.dtype,
+                 )

-                 params = [wpe.weight, new_embeddings.weight]
-                 with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                 # initialize all new embeddings (in particular added tokens)
+                 self._init_weights(new_embeddings)
+                 if is_deepspeed_zero3_enabled() and not is_quantized:
+                     import deepspeed
+
+                     params = [wpe.weight, new_embeddings.weight]
+                     with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                         new_embeddings.weight.data[:max_position, :] = (
+                             wpe.weight.data[:max_position, :]
+                         )
+                 else:
                      new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
                          :max_position, :
                      ]
-             else:
-                 new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
-                     :max_position, :
-                 ]
-             self.set_position_embeddings(new_embeddings)
-             self.config.max_position_embeddings = new_num_position_embeddings
-             self.update_attn_bias(new_num_position_embeddings)
+                 self.set_position_embeddings(new_embeddings)
+                 self.config.max_position_embeddings = new_num_position_embeddings
+                 self.update_attn_bias(new_num_position_embeddings)


  class CEHRGPT2Model(CEHRGPTPreTrainedModel):
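
The only behavioral change here is the "if wpe is not None" guard: models built without a learned position table (get_position_embeddings() returning None) now skip the resize instead of failing on wpe.weight. The copy step itself is unchanged; a toy sketch of that step in plain PyTorch, without the DeepSpeed ZeRO-3 gather branch:

    import torch
    import torch.nn as nn

    old = nn.Embedding(128, 16)                 # existing position table
    new = nn.Embedding(256, 16)                 # freshly initialized, larger table
    with torch.no_grad():
        new.weight[: old.num_embeddings] = old.weight   # keep the learned prefix
    assert torch.equal(new.weight[:128], old.weight)
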
@@ -740,6 +769,10 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          )
          self.update_attn_bias(self.config.sample_packing_max_positions)

+     def enable_position_embeddings(self):
+         self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
+         self.config.exclude_position_ids = False
+
      def initialize_pretrained_embeddings(self):
          layers = [
              nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -1043,7 +1076,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          )

          if not self.exclude_position_ids:
-             position_embeds = self.wpe(position_ids)
+             position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
              hidden_states = input_embeddings + position_embeds
          else:
              hidden_states = input_embeddings
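
Casting the position embeddings to input_embeddings.dtype keeps the sum in the token embeddings' dtype under mixed precision, where the position table may still be fp32; without the cast, PyTorch type promotion silently upcasts the hidden states. A minimal illustration:

    import torch

    tok = torch.randn(2, 4, 8, dtype=torch.bfloat16)   # token embeddings in bf16
    pos = torch.randn(4, 8)                            # position table left in fp32
    print((tok + pos).dtype)                # torch.float32 -- silent upcast
    print((tok + pos.to(tok.dtype)).dtype)  # torch.bfloat16 -- what the new code does
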
@@ -1152,7 +1185,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          super().__init__(config)
          self.cehrgpt = CEHRGPT2Model(config)
          if self.config.include_ttv_prediction:
-             self.tte_head = WeibullModel(config.n_embd)
+             self.tte_head = VisitTimeToEventHead(config.n_embd)

          if self.config.use_sub_time_tokenization:
              self.time_token_lm_head = nn.Linear(
@@ -1165,6 +1198,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              config.n_embd, config.value_vocab_size, bias=False
          )

+         if self.config.include_motor_time_to_event:
+             self.motor_tte = MotorTaskHead(
+                 config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+             )
+
          # Model parallel
          self.model_parallel = False
          self.device_map = None
@@ -1192,6 +1230,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              self.value_head = self.value_head.to(self.cehrgpt.first_device)
          if self.config.include_ttv_prediction:
              self.tte_head = self.tte_head.to(self.cehrgpt.first_device)
+         if self.config.include_motor_time_to_event:
+             self.motor_tte = self.motor_tte.to(self.cehrgpt.first_device)
          self.model_parallel = True

      def deparallelize(self):
@@ -1206,6 +1246,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              self.value_head = self.value_head.to("cpu")
          if self.config.include_ttv_prediction:
              self.tte_head = self.tte_head.to("cpu")
+         if self.config.include_motor_time_to_event:
+             self.motor_tte = self.motor_tte.to("cpu")
          self.model_parallel = False
          torch.cuda.empty_cache()

@@ -1233,6 +1275,28 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
      def update_attn_bias(self, max_position_embeddings: int):
          self.cehrgpt.update_attn_bias(max_position_embeddings)

+     def update_motor_tte_vocab_size(
+         self, motor_tte_vocab_size: Optional[int] = None
+     ) -> None:
+         update_motor_tte_layer = False
+         if motor_tte_vocab_size and motor_tte_vocab_size > 0:
+             if self.config.include_motor_time_to_event:
+                 if self.config.motor_tte_vocab_size != motor_tte_vocab_size:
+                     self.config.include_motor_time_to_event = True
+                     self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                     update_motor_tte_layer = True
+             else:
+                 self.config.include_motor_time_to_event = True
+                 self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                 update_motor_tte_layer = True
+
+         if update_motor_tte_layer:
+             self.motor_tte = MotorTaskHead(
+                 self.config.n_embd,
+                 self.config.motor_tte_vocab_size,
+                 self.config.motor_num_time_pieces,
+             )
+
      def prepare_inputs_for_generation(
          self,
          input_ids,
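
update_motor_tte_vocab_size only rebuilds the MOTOR head when a positive vocabulary size is supplied and either MOTOR was not previously enabled or the size differs from the one in the config, so repeated calls with the same value keep the existing weights. The same decision logic in isolation (a standalone mirror for illustration, not the package API):

    from typing import Optional

    def needs_new_head(include_tte: bool, current: int, requested: Optional[int]) -> bool:
        """Mirror of the guard inside update_motor_tte_vocab_size."""
        if not requested or requested <= 0:
            return False
        return (not include_tte) or current != requested

    assert needs_new_head(False, 0, 100) is True     # enabling MOTOR for the first time
    assert needs_new_head(True, 100, 100) is False   # same vocab size: keep the head
    assert needs_new_head(True, 100, 250) is True    # size changed: rebuild the head
    assert needs_new_head(True, 100, None) is False  # nothing requested: no change
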
@@ -1328,6 +1392,74 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

          return model_inputs

+     def motor_nll_loss(
+         self,
+         ve_token_features,
+         motor_time_to_event_vectors,
+         motor_event_indicators,
+         motor_time_to_event_to_include,
+         motor_time_indicators,
+         batch_motor_end_index,
+     ):
+         """
+         Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
+
+         for modeling time-to-event data at each visit.
+
+         Args:
+             ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
+             motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+             motor_time_to_event_to_include: (Tensor): Bool indicators (True if included, False if not included).
+             motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+             motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
+                 time bucket (1 if censored, 0 if event occurred).
+             batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+
+         Returns:
+             Tensor: Scalar loss value (mean negative log-likelihood).
+         """
+         batch_motor_end_index = batch_motor_end_index.sum().item()
+         motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+         )[:batch_motor_end_index].clamp(min=1e-3)
+         motor_event_indicators = motor_event_indicators.reshape(
+             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+         )[:batch_motor_end_index]
+         motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
+             :batch_motor_end_index
+         ]
+         motor_time_indicators = motor_time_indicators.view(
+             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+         )[:batch_motor_end_index]
+         assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+             "The number of VE tokens in the labels needs to match up "
+             "with the first dimension of motor_time_to_event_vectors. "
+             f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
+             f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+         )
+         motor_time_to_event_vectors = motor_time_to_event_vectors[
+             motor_time_to_event_to_include
+         ]
+         motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
+         motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
+         ve_token_features = ve_token_features[motor_time_to_event_to_include]
+
+         # Get Exponential parameters from model
+         lambda_p = self.motor_tte(ve_token_features)
+         # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
+         dist = Exponential(lambda_p.clamp(min=1e-3))
+
+         # Compute event loss
+         tte_loss = torch.where(
+             motor_event_indicators,
+             -dist.log_prob(motor_time_to_event_vectors),
+             -torch.log(
+                 1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
+             ),
+         )
+         tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
+         return torch.mean(tte_loss)
+
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
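
Despite the docstring's mention of a LogNormal, the implementation models each (time piece, concept) cell with an Exponential rate from MotorTaskHead: one branch of the torch.where takes the negative log-density, the other the negative log of the survival function (1 - CDF), with small clamps for numerical safety, and motor_time_indicators zeroes out pieces that do not apply. A standalone numerical check of the two branches for a single rate and duration:

    import torch
    from torch.distributions import Exponential

    lam = torch.tensor(0.5)   # hazard rate for one (time piece, concept) cell
    t = torch.tensor(2.0)     # duration observed within that piece

    dist = Exponential(lam)
    event_term = -dist.log_prob(t)             # -log f(t) = -log(lam) + lam * t
    censor_term = -torch.log(1 - dist.cdf(t))  # -log S(t) = lam * t

    print(event_term.item())   # ~1.6931  (= -log 0.5 + 0.5 * 2.0)
    print(censor_term.item())  # ~1.0     (= 0.5 * 2.0)
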
@@ -1344,6 +1476,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          time_to_visits: Optional[torch.FloatTensor] = None,
          time_token_indicators: Optional[torch.BoolTensor] = None,
          sub_time_tokens: Optional[torch.LongTensor] = None,
+         motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
+         motor_event_indicators: Optional[torch.BoolTensor] = None,
+         motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
+         motor_time_indicators: Optional[torch.BoolTensor] = None,
+         motor_end_index: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
@@ -1403,6 +1540,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          time_token_loss = None
          time_to_visit_loss = None
          token_value_loss = None
+         motor_tte_loss = None
+
          if labels is not None:
              # move labels to correct device to enable model parallelism
              labels = labels.to(lm_logits.device)
@@ -1470,9 +1609,35 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                  entropy_penalty = entropy.sum() / total_num_tokens
                  loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha

+             if (
+                 self.config.include_motor_time_to_event
+                 and motor_time_to_event_vectors is not None
+                 and motor_event_indicators is not None
+                 and motor_time_to_event_to_include is not None
+                 and motor_time_indicators is not None
+                 and motor_end_index is not None
+             ):
+                 ve_token_id_indices = labels == self.config.ve_token_id
+                 ve_token_features = hidden_states[ve_token_id_indices]
+                 # Get rid of the last VE features because it's already reached the end of the patient sequence and
+                 # there is nothing to predict.
+                 motor_tte_loss = self.motor_nll_loss(
+                     ve_token_features=ve_token_features,
+                     motor_time_to_event_vectors=motor_time_to_event_vectors,
+                     motor_event_indicators=motor_event_indicators,
+                     motor_time_to_event_to_include=motor_time_to_event_to_include,
+                     motor_time_indicators=motor_time_indicators,
+                     batch_motor_end_index=motor_end_index,
+                 )
+                 loss += motor_tte_loss * self.config.motor_time_to_event_weight
+
              # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
              # predictions for year/month/token
-             if self.config.use_sub_time_tokenization:
+             if (
+                 self.config.use_sub_time_tokenization
+                 and sub_time_tokens is not None
+                 and time_token_indicators is not None
+             ):
                  # Split the last dimensions into three parts
                  time_loss_fct = CrossEntropyLoss(reduction="none")
                  time_token_logits = self.time_token_lm_head(
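
In forward, the MOTOR loss is attached only at positions whose label is the visit-end token: a boolean mask over labels picks out those positions and boolean indexing flattens the matching hidden states into a (num_ve_tokens, hidden_dim) matrix that motor_nll_loss consumes. A small sketch of that gather (the token id 99 is made up; the real value comes from config.ve_token_id):

    import torch

    hidden_states = torch.randn(2, 5, 8)             # (batch, seq_len, hidden_dim)
    labels = torch.tensor([[7, 3, 99, 4, 99],
                           [99, 2, 6, 99, 1]])       # 99 stands in for ve_token_id
    ve_mask = labels == 99                            # (batch, seq_len) boolean mask
    ve_features = hidden_states[ve_mask]              # (4, 8): one row per [VE] position
    print(ve_features.shape)
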
@@ -1501,7 +1666,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                  time_token_loss = time_token_loss.sum() / total_num_tokens
                  loss += time_token_loss * self.config.time_token_loss_weight

-             if time_to_visits is not None:
+             if time_to_visits is not None and time_to_visits is not None:
                  # Get lambda and k parameters
                  lambda_param, k_param = self.tte_head(hidden_states)

@@ -1512,14 +1677,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

                  # Move to the same device as lambda_param
                  shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
-
                  time_to_visit_indicator = shift_time_to_visits >= 0
                  # Define the Gamma distribution
                  dist = Gamma(
                      shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
                  )
                  # Compute log-probs and apply the time_to_visit_indicator
-                 log_probs = dist.log_prob(torch.clamp(shift_time_to_visits, min=1e-3))
+                 log_probs = dist.log_prob(
+                     torch.clamp(shift_time_to_visits, min=1e-3) + 1e-6
+                 )
                  log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
                  time_to_visit_loss = -log_probs.sum() / total_num_tokens
                  # Compute the loss
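
The time-to-visit term already clamps durations to a minimum of 1e-3 before evaluating the Gamma log-density; 0.1.1 adds a further 1e-6 on top of the clamp, presumably as extra numerical headroom, since the log-density grows without bound as the argument approaches zero when the shape parameter is below one. For reference:

    import torch
    from torch.distributions import Gamma

    dist = Gamma(torch.tensor(0.5), torch.tensor(1.0))   # shape < 1: log-density blows up near 0
    for t in (1e-1, 1e-3, 1e-12):
        print(t, dist.log_prob(torch.tensor(t)).item())  # ~0.48, ~2.88, ~13.2
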
@@ -1564,6 +1730,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              time_token_loss=time_token_loss,
              time_to_visit_loss=time_to_visit_loss,
              token_value_loss=token_value_loss,
+             motor_tte_loss=motor_tte_loss,
          )

      @staticmethod
@@ -1837,6 +2004,27 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          )


+ class FocalLoss(nn.Module):
+     def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
+         super().__init__()
+         self.alpha = alpha
+         self.gamma = gamma
+         self.reduction = reduction
+
+     def forward(self, logits, targets):
+         bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+         probs = torch.sigmoid(logits)
+         pt = torch.where(targets == 1, probs, 1 - probs)
+         focal_term = (1 - pt) ** self.gamma
+         loss = self.alpha * focal_term * bce_loss
+
+         if self.reduction == "mean":
+             return loss.mean()
+         elif self.reduction == "sum":
+             return loss.sum()
+         return loss
+
+
  class CehrGptForClassification(CEHRGPTPreTrainedModel):
      _keep_in_fp32_modules = ["age_batch_norm", "dense_layer", "classifier"]
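
FocalLoss scales the per-element binary cross-entropy by alpha * (1 - p_t)**gamma, so confidently correct predictions contribute almost nothing while hard or misclassified ones dominate. A quick check of the weighting, assuming the FocalLoss class from the hunk above is in scope:

    import torch
    import torch.nn.functional as F

    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    logits = torch.tensor([4.0, -4.0, 0.1])     # confident pos, confident neg, uncertain
    targets = torch.tensor([1.0, 0.0, 1.0])
    print(criterion(logits, targets))           # small: the two easy examples are down-weighted

    # with alpha=1 and gamma=0 the focal term vanishes and it reduces to plain BCE-with-logits
    plain = FocalLoss(alpha=1.0, gamma=0.0)
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    assert torch.allclose(plain(logits, targets), bce)
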
 
@@ -1859,7 +2047,6 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
          self.model_parallel = False
          self.device_map = None
          self.gradient_checkpointing = False
-
          # Initialize weights and apply final processing
          self.post_init()

@@ -1971,7 +2158,14 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):

          loss = None
          if classifier_label is not None:
-             loss_fct = nn.BCEWithLogitsLoss()
+             if self.config.class_weights:
+                 class_weights = torch.tensor(
+                     [self.config.class_weights[1] / self.config.class_weights[0]],
+                     dtype=torch.float32,
+                 ).to(logits.device)
+             else:
+                 class_weights = None
+             loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
              loss = loss_fct(logits, classifier_label)

          return CehrGptSequenceClassifierOutput(
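
When config.class_weights is set, the classifier now passes class_weights[1] / class_weights[0] as pos_weight to BCEWithLogitsLoss, which multiplies only the loss on positive-label examples; otherwise the behavior is unchanged. A standalone illustration of the effect:

    import torch
    import torch.nn as nn

    logits = torch.tensor([0.0, 0.0])
    targets = torch.tensor([1.0, 0.0])

    plain = nn.BCEWithLogitsLoss()(logits, targets)
    weighted = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]))(logits, targets)

    # pos_weight scales only the positive-target term:
    #   plain    = (0.6931 + 0.6931) / 2 ~= 0.6931
    #   weighted = (4 * 0.6931 + 0.6931) / 2 ~= 1.7329
    print(plain.item(), weighted.item())
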
@@ -85,6 +85,7 @@ class CehrGptCausalLMOutput(ModelOutput):
      time_token_loss: Optional[torch.FloatTensor] = None
      time_to_visit_loss: Optional[torch.FloatTensor] = None
      token_value_loss: Optional[torch.FloatTensor] = None
+     motor_tte_loss: Optional[torch.FloatTensor] = None


  @dataclass
@@ -3,6 +3,7 @@ START_TOKEN = "[START]"
  END_TOKEN = "[END]"
  PAD_TOKEN = "[PAD]"
  OUT_OF_VOCABULARY_TOKEN = "[OOV]"
+ LINEAR_PROB_TOKEN = "[LINEAR_PROB]"

  # OMOP CONCEPT IDs
  VISIT_CONCEPT_IDS = [