cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (29)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  5. cehrgpt/data/sample_packing_sampler.py +36 -6
  6. cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  7. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  8. cehrgpt/generation/omop_converter_batch.py +32 -2
  9. cehrgpt/gpt_utils.py +20 -2
  10. cehrgpt/models/config.py +25 -0
  11. cehrgpt/models/hf_cehrgpt.py +244 -39
  12. cehrgpt/models/hf_modeling_outputs.py +1 -0
  13. cehrgpt/models/special_tokens.py +1 -0
  14. cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  15. cehrgpt/runners/data_utils.py +131 -5
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +6 -7
  20. cehrgpt/runners/sample_packing_trainer.py +17 -0
  21. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  22. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  23. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  25. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
  26. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
  27. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
  28. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
  29. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import Gamma, Weibull
+from torch.distributions import Exponential, Gamma
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -102,7 +102,9 @@ def is_sample_pack(attention_mask: torch.Tensor) -> bool:
     attention_mask = attention_mask.flip(dims=[1])

     nonzero_counts = attention_mask.sum(dim=1)
-    max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
+    max_token_positions = torch.argmax(
+        attention_mask.to(torch.int32).flip(dims=[1]), dim=1
+    )
     max_indices = attention_mask.shape[1] - 1 - max_token_positions
     return torch.any(nonzero_counts < (max_indices + 1)).item()

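For reference, a standalone sketch of the check the hunk above tightens (the .to(torch.int32) cast avoids calling argmax on a boolean mask): a right-padded row has a contiguous prefix of ones, so its count of ones equals the span up to its last attended position, while a packed row has internal zeros and fails that equality. The tensors below are toy values, and the flip of left-padded masks handled earlier in the real helper is omitted.

import torch

def has_internal_padding(attention_mask: torch.Tensor) -> bool:
    # Count of attended positions per row.
    nonzero_counts = attention_mask.sum(dim=1)
    # Index of the last attended position per row (cast so argmax accepts the mask).
    max_token_positions = torch.argmax(
        attention_mask.to(torch.int32).flip(dims=[1]), dim=1
    )
    max_indices = attention_mask.shape[1] - 1 - max_token_positions
    # Packed rows have gaps, so fewer ones than the span up to the last one.
    return torch.any(nonzero_counts < (max_indices + 1)).item()

right_padded = torch.tensor([[1, 1, 1, 0, 0]])
packed = torch.tensor([[1, 1, 0, 1, 1]])
print(has_internal_padding(right_padded))  # False
print(has_internal_padding(packed))        # True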
@@ -362,9 +364,37 @@ class GPT2FlashAttention(GPT2Attention):
         )


-class WeibullModel(nn.Module):
+class MotorTaskHead(nn.Module):
+    def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
+        super(MotorTaskHead, self).__init__()
+        self.input_dim = input_dim
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.linear = nn.Sequential(
+            nn.Linear(input_dim, input_dim // 2),
+            gelu_new,
+            nn.Linear(
+                input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
+            ),
+        )
+
+    def forward(self, x):
+        # Ensure scale is positive
+        length = x.shape[0]
+        # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
+        lambda_p = f.softplus(self.linear(x))
+        # Check for NaN values
+        if torch.isnan(lambda_p).any():
+            logger.warning(f"NaN values found in scale_param. x: {x}")
+        # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+        return lambda_p.view(
+            length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+        )
+
+
+class VisitTimeToEventHead(nn.Module):
     def __init__(self, input_dim):
-        super(WeibullModel, self).__init__()
+        super(VisitTimeToEventHead, self).__init__()
         self.linear1 = nn.Sequential(
             nn.Linear(input_dim, input_dim // 2), gelu_new, nn.Linear(input_dim // 2, 1)
         )
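The new MotorTaskHead maps each visit-end ([VE]) token representation to one non-negative rate per (time piece, TTE concept) pair, with softplus enforcing positivity. A minimal shape sketch, using made-up sizes and nn.GELU as a stand-in for the module's gelu_new activation:

import torch
import torch.nn.functional as f
from torch import nn

# Hypothetical sizes for illustration only.
input_dim, motor_tte_vocab_size, motor_num_time_pieces = 768, 100, 8

linear = nn.Sequential(
    nn.Linear(input_dim, input_dim // 2),
    nn.GELU(),  # stand-in for gelu_new
    nn.Linear(input_dim // 2, motor_tte_vocab_size * motor_num_time_pieces),
)

x = torch.randn(32, input_dim)        # 32 [VE] token features in the batch
lambda_p = f.softplus(linear(x))      # strictly positive rates
lambda_p = lambda_p.view(32, motor_num_time_pieces, motor_tte_vocab_size)
print(lambda_p.shape)                 # torch.Size([32, 8, 100])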
@@ -661,32 +691,33 @@ class CEHRGPTPreTrainedModel(PreTrainedModel):
             hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         )
         wpe = self.get_position_embeddings()
-        max_position, embed_dim = wpe.weight.shape
-        if new_num_position_embeddings > max_position:
-            new_embeddings = nn.Embedding(
-                new_num_position_embeddings,
-                embed_dim,
-                device=wpe.weight.device,
-                dtype=wpe.weight.dtype,
-            )
-
-            # initialize all new embeddings (in particular added tokens)
-            self._init_weights(new_embeddings)
-            if is_deepspeed_zero3_enabled() and not is_quantized:
-                import deepspeed
+        if wpe is not None:
+            max_position, embed_dim = wpe.weight.shape
+            if new_num_position_embeddings > max_position:
+                new_embeddings = nn.Embedding(
+                    new_num_position_embeddings,
+                    embed_dim,
+                    device=wpe.weight.device,
+                    dtype=wpe.weight.dtype,
+                )

-                params = [wpe.weight, new_embeddings.weight]
-                with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                # initialize all new embeddings (in particular added tokens)
+                self._init_weights(new_embeddings)
+                if is_deepspeed_zero3_enabled() and not is_quantized:
+                    import deepspeed
+
+                    params = [wpe.weight, new_embeddings.weight]
+                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                        new_embeddings.weight.data[:max_position, :] = (
+                            wpe.weight.data[:max_position, :]
+                        )
+                else:
                     new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
                         :max_position, :
                     ]
-            else:
-                new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
-                    :max_position, :
-                ]
-            self.set_position_embeddings(new_embeddings)
-            self.config.max_position_embeddings = new_num_position_embeddings
-            self.update_attn_bias(new_num_position_embeddings)
+                self.set_position_embeddings(new_embeddings)
+                self.config.max_position_embeddings = new_num_position_embeddings
+                self.update_attn_bias(new_num_position_embeddings)


 class CEHRGPT2Model(CEHRGPTPreTrainedModel):
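The resize in the hunk above now no-ops when the model was built without position embeddings (wpe is None). The growth path itself just allocates a larger table and copies the old rows; a self-contained sketch of that copy, leaving out the DeepSpeed ZeRO-3 gathering branch shown in the diff:

import torch
from torch import nn

def grow_position_embeddings(wpe: nn.Embedding, new_num_positions: int) -> nn.Embedding:
    # Keep the existing table when nothing needs to grow.
    max_position, embed_dim = wpe.weight.shape
    if new_num_positions <= max_position:
        return wpe
    new_wpe = nn.Embedding(
        new_num_positions, embed_dim, device=wpe.weight.device, dtype=wpe.weight.dtype
    )
    with torch.no_grad():
        # Copy the trained rows; the extra rows keep their fresh initialization.
        new_wpe.weight[:max_position, :] = wpe.weight
    return new_wpe

wpe = nn.Embedding(512, 16)
print(grow_position_embeddings(wpe, 1024).weight.shape)  # torch.Size([1024, 16])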
@@ -740,6 +771,10 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             )
             self.update_attn_bias(self.config.sample_packing_max_positions)

+    def enable_position_embeddings(self):
+        self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
+        self.config.exclude_position_ids = False
+
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -1043,7 +1078,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             )

         if not self.exclude_position_ids:
-            position_embeds = self.wpe(position_ids)
+            position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
             hidden_states = input_embeddings + position_embeds
         else:
             hidden_states = input_embeddings
@@ -1152,7 +1187,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         super().__init__(config)
         self.cehrgpt = CEHRGPT2Model(config)
         if self.config.include_ttv_prediction:
-            self.tte_head = WeibullModel(config.n_embd)
+            self.tte_head = VisitTimeToEventHead(config.n_embd)

         if self.config.use_sub_time_tokenization:
             self.time_token_lm_head = nn.Linear(
@@ -1165,6 +1200,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 config.n_embd, config.value_vocab_size, bias=False
             )

+        if self.config.include_motor_time_to_event:
+            self.motor_tte = MotorTaskHead(
+                config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+            )
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -1192,6 +1232,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             self.value_head = self.value_head.to(self.cehrgpt.first_device)
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to(self.cehrgpt.first_device)
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to(self.cehrgpt.first_device)
         self.model_parallel = True

     def deparallelize(self):
@@ -1206,6 +1248,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             self.value_head = self.value_head.to("cpu")
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to("cpu")
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()

@@ -1233,6 +1277,28 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
     def update_attn_bias(self, max_position_embeddings: int):
         self.cehrgpt.update_attn_bias(max_position_embeddings)

+    def update_motor_tte_vocab_size(
+        self, motor_tte_vocab_size: Optional[int] = None
+    ) -> None:
+        update_motor_tte_layer = False
+        if motor_tte_vocab_size and motor_tte_vocab_size > 0:
+            if self.config.include_motor_time_to_event:
+                if self.config.motor_tte_vocab_size != motor_tte_vocab_size:
+                    self.config.include_motor_time_to_event = True
+                    self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                    update_motor_tte_layer = True
+            else:
+                self.config.include_motor_time_to_event = True
+                self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                update_motor_tte_layer = True
+
+        if update_motor_tte_layer:
+            self.motor_tte = MotorTaskHead(
+                self.config.n_embd,
+                self.config.motor_tte_vocab_size,
+                self.config.motor_num_time_pieces,
+            )
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
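A hypothetical call pattern for the new update_motor_tte_vocab_size hook: attach (or rebuild) the MOTOR time-to-event head on a loaded checkpoint when the task's TTE vocabulary differs from the one in the config. The checkpoint path and vocabulary size below are placeholders, and the config is assumed to already carry motor_num_time_pieces.

from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel

model = CEHRGPT2LMHeadModel.from_pretrained("path/to/cehrgpt-checkpoint")  # placeholder path
# A fresh MotorTaskHead is attached only when the vocab size is new or has changed.
model.update_motor_tte_vocab_size(motor_tte_vocab_size=512)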
@@ -1328,6 +1394,74 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

         return model_inputs

+    def motor_nll_loss(
+        self,
+        ve_token_features,
+        motor_time_to_event_vectors,
+        motor_event_indicators,
+        motor_time_to_event_to_include,
+        motor_time_indicators,
+        batch_motor_end_index,
+    ):
+        """
+        Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
+
+        for modeling time-to-event data at each visit.
+
+        Args:
+            ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
+            motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+            motor_time_to_event_to_include: (Tensor): Bool indicators (True if included, False if not included).
+            motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+            motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
+                time bucket (1 if censored, 0 if event occurred).
+            batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+
+        Returns:
+            Tensor: Scalar loss value (mean negative log-likelihood).
+        """
+        batch_motor_end_index = batch_motor_end_index.sum().item()
+        motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index].clamp(min=1e-3)
+        motor_event_indicators = motor_event_indicators.reshape(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
+            :batch_motor_end_index
+        ]
+        motor_time_indicators = motor_time_indicators.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+            "The number of VE tokens in the labels needs to match up "
+            "with the first dimension of motor_time_to_event_vectors. "
+            f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
+            f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+        )
+        motor_time_to_event_vectors = motor_time_to_event_vectors[
+            motor_time_to_event_to_include
+        ]
+        motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
+        motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
+        ve_token_features = ve_token_features[motor_time_to_event_to_include]
+
+        # Get Exponential parameters from model
+        lambda_p = self.motor_tte(ve_token_features)
+        # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
+        dist = Exponential(lambda_p.clamp(min=1e-3))
+
+        # Compute event loss
+        tte_loss = torch.where(
+            motor_event_indicators,
+            -dist.log_prob(motor_time_to_event_vectors),
+            -torch.log(
+                1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
+            ),
+        )
+        tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
+        return torch.mean(tte_loss)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
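motor_nll_loss is a piecewise censored exponential likelihood: entries on one branch of the indicator take the negative log-density, the others take the negative log of the survival function, and pieces the time does not fall into are zeroed out. A self-contained numeric sketch with toy shapes (2 visits, 3 time pieces, 4 TTE concepts); all tensors are random, and the indicator here uses the plain observed/censored convention rather than the diff's exact field semantics.

import torch
from torch.distributions import Exponential

# Toy sizes: 2 visits, 3 time pieces, 4 TTE concepts (all hypothetical).
rates = torch.rand(2, 3, 4).clamp(min=1e-3)      # what MotorTaskHead would emit
times = torch.rand(2, 3, 4).clamp(min=1e-3)      # event / censoring times per piece
event_observed = torch.rand(2, 3, 4) > 0.5       # density branch where True
in_piece = torch.rand(2, 3, 4) > 0.2             # only score pieces the time falls into

dist = Exponential(rates)
nll = torch.where(
    event_observed,
    -dist.log_prob(times),                        # -log f(t) = -(log rate - rate * t)
    -torch.log(1 - dist.cdf(times).clamp(max=1 - 1e-6) + 1e-6),  # -log S(t) ~= rate * t
)
loss = torch.where(in_piece, nll, 0.0).mean()
print(loss)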
@@ -1344,6 +1478,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_to_visits: Optional[torch.FloatTensor] = None,
         time_token_indicators: Optional[torch.BoolTensor] = None,
         sub_time_tokens: Optional[torch.LongTensor] = None,
+        motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
+        motor_event_indicators: Optional[torch.BoolTensor] = None,
+        motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
+        motor_time_indicators: Optional[torch.BoolTensor] = None,
+        motor_end_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1403,6 +1542,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_token_loss = None
         time_to_visit_loss = None
         token_value_loss = None
+        motor_tte_loss = None
+
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
@@ -1470,9 +1611,35 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 entropy_penalty = entropy.sum() / total_num_tokens
                 loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha

+            if (
+                self.config.include_motor_time_to_event
+                and motor_time_to_event_vectors is not None
+                and motor_event_indicators is not None
+                and motor_time_to_event_to_include is not None
+                and motor_time_indicators is not None
+                and motor_end_index is not None
+            ):
+                ve_token_id_indices = labels == self.config.ve_token_id
+                ve_token_features = hidden_states[ve_token_id_indices]
+                # Get rid of the last VE features because it's already reached the end of the patient sequence and
+                # there is nothing to predict.
+                motor_tte_loss = self.motor_nll_loss(
+                    ve_token_features=ve_token_features,
+                    motor_time_to_event_vectors=motor_time_to_event_vectors,
+                    motor_event_indicators=motor_event_indicators,
+                    motor_time_to_event_to_include=motor_time_to_event_to_include,
+                    motor_time_indicators=motor_time_indicators,
+                    batch_motor_end_index=motor_end_index,
+                )
+                loss += motor_tte_loss * self.config.motor_time_to_event_weight
+
             # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
             # predictions for year/month/token
-            if self.config.use_sub_time_tokenization:
+            if (
+                self.config.use_sub_time_tokenization
+                and sub_time_tokens is not None
+                and time_token_indicators is not None
+            ):
                 # Split the last dimensions into three parts
                 time_loss_fct = CrossEntropyLoss(reduction="none")
@@ -1501,7 +1668,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 time_token_loss = time_token_loss.sum() / total_num_tokens
                 loss += time_token_loss * self.config.time_token_loss_weight

-            if time_to_visits is not None:
+            if time_to_visits is not None and time_to_visits is not None:
                 # Get lambda and k parameters
                 lambda_param, k_param = self.tte_head(hidden_states)

@@ -1512,14 +1679,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

                 # Move to the same device as lambda_param
                 shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
-
                 time_to_visit_indicator = shift_time_to_visits >= 0
                 # Define the Gamma distribution
                 dist = Gamma(
                     shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
                 )
                 # Compute log-probs and apply the time_to_visit_indicator
-                log_probs = dist.log_prob(torch.clamp(shift_time_to_visits, min=1e-3))
+                log_probs = dist.log_prob(
+                    torch.clamp(shift_time_to_visits, min=1e-3) + 1e-6
+                )
                 log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
                 time_to_visit_loss = -log_probs.sum() / total_num_tokens
                 # Compute the loss
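The time-to-visit loss still models the gap to the next visit with a Gamma distribution; the change above only nudges the clamped times by a small epsilon before log_prob as an extra numerical guard. A toy illustration with hypothetical concentration/rate values:

import torch
from torch.distributions import Gamma

k = torch.tensor([1.5, 0.8])                # concentration, as tte_head would emit (toy values)
lam = torch.tensor([0.2, 0.5])              # rate
time_to_visit = torch.tensor([0.0, 12.0])   # days until the next visit

dist = Gamma(k, lam)
# The clamp keeps same-day gaps inside the support; the epsilon is an extra guard on top of it.
log_probs = dist.log_prob(torch.clamp(time_to_visit, min=1e-3) + 1e-6)
print(log_probs)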
@@ -1564,6 +1732,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             time_token_loss=time_token_loss,
             time_to_visit_loss=time_to_visit_loss,
             token_value_loss=token_value_loss,
+            motor_tte_loss=motor_tte_loss,
         )

     @staticmethod
@@ -1681,6 +1850,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape
+        model_kwargs["attention_mask"] = input_ids != pad_token_id
         if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
         this_peer_finished = False
@@ -1699,11 +1869,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             [] if self.config.lab_token_ids is None else self.config.lab_token_ids,
             dtype=torch.int32,
         )
-        value_indicators = torch.zeros_like(input_ids).to(torch.bool)
-        values = torch.zeros_like(
-            input_ids,
-            dtype=torch.int32,
-        )
+
+        if model_kwargs.get("value_indicators", None) is not None:
+            value_indicators = model_kwargs.get("value_indicators")
+        else:
+            value_indicators = torch.zeros_like(input_ids).to(torch.bool)
+
+        if model_kwargs.get("values", None) is not None:
+            values = model_kwargs.get("values")
+        else:
+            values = torch.zeros_like(
+                input_ids,
+                dtype=torch.int32,
+            )
         # Generate initial random_vectors
         if self.cehrgpt.config.causal_sfm:
             model_kwargs["random_vectors"] = torch.rand(
@@ -1837,6 +2015,27 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         )


+class FocalLoss(nn.Module):
+    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
+        super().__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, logits, targets):
+        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+        probs = torch.sigmoid(logits)
+        pt = torch.where(targets == 1, probs, 1 - probs)
+        focal_term = (1 - pt) ** self.gamma
+        loss = self.alpha * focal_term * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
 class CehrGptForClassification(CEHRGPTPreTrainedModel):
     _keep_in_fp32_modules = ["age_batch_norm", "dense_layer", "classifier"]

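FocalLoss above down-weights well-classified examples through the (1 - pt)**gamma term on top of per-element BCE. A quick usage sketch on toy logits and targets, assuming the class is importable from cehrgpt.models.hf_cehrgpt where the hunk defines it at module level:

import torch
from cehrgpt.models.hf_cehrgpt import FocalLoss  # defined in the hunk above

logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()

loss_fn = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
print(loss_fn(logits, targets))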
@@ -1859,7 +2058,6 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         self.model_parallel = False
         self.device_map = None
         self.gradient_checkpointing = False
-
         # Initialize weights and apply final processing
         self.post_init()

@@ -1971,7 +2169,14 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):

         loss = None
         if classifier_label is not None:
-            loss_fct = nn.BCEWithLogitsLoss()
+            if self.config.class_weights:
+                class_weights = torch.tensor(
+                    [self.config.class_weights[1] / self.config.class_weights[0]],
+                    dtype=torch.float32,
+                ).to(logits.device)
+            else:
+                class_weights = None
+            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
             loss = loss_fct(logits, classifier_label)

         return CehrGptSequenceClassifierOutput(
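The classifier now feeds a pos_weight into BCEWithLogitsLoss, scaling the positive-class term by class_weights[1] / class_weights[0]. A toy sketch of the effect, assuming class_weights holds per-class weights in [negative, positive] order (the diff does not show how the runner populates config.class_weights):

import torch
from torch import nn

class_weights = [1.0, 9.0]                 # hypothetical [w_negative, w_positive]
pos_weight = torch.tensor(
    [class_weights[1] / class_weights[0]], dtype=torch.float32
)

logits = torch.randn(4, 1)
labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
print(loss_fct(logits, labels))            # positive examples contribute 9x to the loss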
cehrgpt/models/hf_modeling_outputs.py
@@ -85,6 +85,7 @@ class CehrGptCausalLMOutput(ModelOutput):
     time_token_loss: Optional[torch.FloatTensor] = None
     time_to_visit_loss: Optional[torch.FloatTensor] = None
     token_value_loss: Optional[torch.FloatTensor] = None
+    motor_tte_loss: Optional[torch.FloatTensor] = None


 @dataclass
cehrgpt/models/special_tokens.py
@@ -3,6 +3,7 @@ START_TOKEN = "[START]"
 END_TOKEN = "[END]"
 PAD_TOKEN = "[PAD]"
 OUT_OF_VOCABULARY_TOKEN = "[OOV]"
+LINEAR_PROB_TOKEN = "[LINEAR_PROB]"

 # OMOP CONCEPT IDs
 VISIT_CONCEPT_IDS = [