cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ import numpy as np
  import torch
  import torch.nn.functional as f
  from torch import nn
- from torch.distributions import Gamma
+ from torch.distributions import Exponential, Gamma
  from torch.nn import CrossEntropyLoss
  from torch.nn import functional as F
  from transformers import PreTrainedModel
@@ -45,12 +45,108 @@ if is_accelerate_available():
  logger = logging.get_logger(__name__)


+ def extract_features_from_packed_sequence(
+ hidden_state: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
+ padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
+ feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
+ return hidden_state[:, feature_indices]
+
+
+ def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+ """
+ Create a block-diagonal attention mask for packed sequences within a batch.
+
+ Args:
+ attention_mask (torch.Tensor): (batch_size, seq_len) binary mask where 1 = token, 0 = padding
+
+ Returns:
+ torch.Tensor: (batch_size, seq_len, seq_len) attention mask where entries are 1 if tokens
+ can attend to each other (within same packed segment), 0 otherwise.
+ """
+ # Step 1: Identify segments within each sample
+ cumsum_mask = (attention_mask == 0).cumsum(dim=-1)
+ segment_ids = cumsum_mask * attention_mask # zeros remain zero
+
+ # Step 2: Compare segment IDs pairwise per batch element
+ # Shape: (batch_size, seq_len, seq_len)
+ attn_matrix = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).int()
+
+ # Step 3: Mask out padding tokens
+ mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
+ attn_matrix = attn_matrix * mask
+
+ return attn_matrix
+
+
+ def is_sample_pack(attention_mask: torch.Tensor) -> bool:
+ """
+ Determines whether any sequence in the batch is likely sample-packed.
+
+ A sample-packed sequence is one where there are non-padding (1) tokens
+ after a padding (0) token, indicating multiple sequences packed together
+ with padding as a separator.
+
+ Args:
+ attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
+ where 1 indicates a real token and 0 indicates padding.
+
+ Returns:
+ bool: True if any sample in the batch is sample-packed, False otherwise.
+ """
+
+ # If the attention_mask is left padded, we will flip it so we can use the same logic below
+ if (attention_mask[:, 0] == 0).any():
+ attention_mask = attention_mask.flip(dims=[1])
+
+ nonzero_counts = attention_mask.sum(dim=1)
+ max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
+ max_indices = attention_mask.shape[1] - 1 - max_token_positions
+ return torch.any(nonzero_counts < (max_indices + 1)).item()
+
+
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
  def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ # This infers sample packing
+ if is_sample_pack(attention_mask):
+ # Assume input: attention_mask shape = (batch, seq_len)
+ attention_mask = attention_mask.flatten() # shape: (seq_len,)
+
+ # Compute max_index of the last non-zero element
+ nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
+ max_index = nonzero[-1].item()
+
+ # Pad the truncated attention mask
+ padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
+
+ # Indices of all tokens
+ indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
+
+ # Find where 0s occur (segment boundaries)
+ cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
+ padded_attention_mask == 0
+ ]
+
+ # Compute seqlens per segment
+ seqlens_in_batch = (
+ cumsum_seqlens_in_batch
+ - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
+ ).to(torch.int)
+
+ max_seqlen_in_batch = (
+ seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
+ )
+ cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
+ else:
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+
  return (
  indices,
  cu_seqlens,
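For orientation, a minimal usage sketch of the sample-packing helpers introduced above. It assumes the functions are importable from cehrgpt.models.hf_cehrgpt (the module shown in this diff) and uses a made-up toy mask:

import torch
from cehrgpt.models.hf_cehrgpt import (
    create_sample_packing_attention_mask,
    is_sample_pack,
)

# Toy right-padded row: two mini-sequences of length 2 packed together,
# separated and followed by padding (0s).
packed_mask = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])

print(is_sample_pack(packed_mask))  # True: real tokens appear after a 0
block_mask = create_sample_packing_attention_mask(packed_mask)
# block_mask[0] is block-diagonal: positions {0, 1} attend only to each other,
# positions {3, 4} attend only to each other, and padding rows/columns stay 0.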
@@ -266,9 +362,37 @@ class GPT2FlashAttention(GPT2Attention):
  )


- class WeibullModel(nn.Module):
+ class MotorTaskHead(nn.Module):
+ def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
+ super(MotorTaskHead, self).__init__()
+ self.input_dim = input_dim
+ self.motor_tte_vocab_size = motor_tte_vocab_size
+ self.motor_num_time_pieces = motor_num_time_pieces
+ self.linear = nn.Sequential(
+ nn.Linear(input_dim, input_dim // 2),
+ gelu_new,
+ nn.Linear(
+ input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
+ ),
+ )
+
+ def forward(self, x):
+ # Ensure the rate parameter is positive
+ length = x.shape[0]
+ # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
+ lambda_p = f.softplus(self.linear(x))
+ # Check for NaN values
+ if torch.isnan(lambda_p).any():
+ logger.warning(f"NaN values found in lambda_p. x: {x}")
+ # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size)
+ return lambda_p.view(
+ length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+ )
+
+
+ class VisitTimeToEventHead(nn.Module):
  def __init__(self, input_dim):
- super(WeibullModel, self).__init__()
+ super(VisitTimeToEventHead, self).__init__()
  self.linear1 = nn.Sequential(
  nn.Linear(input_dim, input_dim // 2), gelu_new, nn.Linear(input_dim // 2, 1)
@@ -565,32 +689,33 @@ class CEHRGPTPreTrainedModel(PreTrainedModel):
  hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  )
  wpe = self.get_position_embeddings()
- max_position, embed_dim = wpe.weight.shape
- if new_num_position_embeddings > max_position:
- new_embeddings = nn.Embedding(
- new_num_position_embeddings,
- embed_dim,
- device=wpe.weight.device,
- dtype=wpe.weight.dtype,
- )
-
- # initialize all new embeddings (in particular added tokens)
- self._init_weights(new_embeddings)
- if is_deepspeed_zero3_enabled() and not is_quantized:
- import deepspeed
+ if wpe is not None:
+ max_position, embed_dim = wpe.weight.shape
+ if new_num_position_embeddings > max_position:
+ new_embeddings = nn.Embedding(
+ new_num_position_embeddings,
+ embed_dim,
+ device=wpe.weight.device,
+ dtype=wpe.weight.dtype,
+ )

- params = [wpe.weight, new_embeddings.weight]
- with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+ # initialize all new embeddings (in particular added tokens)
+ self._init_weights(new_embeddings)
+ if is_deepspeed_zero3_enabled() and not is_quantized:
+ import deepspeed
+
+ params = [wpe.weight, new_embeddings.weight]
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+ new_embeddings.weight.data[:max_position, :] = (
+ wpe.weight.data[:max_position, :]
+ )
+ else:
  new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
  :max_position, :
  ]
- else:
- new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
- :max_position, :
- ]
- self.set_position_embeddings(new_embeddings)
- self.config.max_position_embeddings = new_num_position_embeddings
- self.update_attn_bias(new_num_position_embeddings)
+ self.set_position_embeddings(new_embeddings)
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.update_attn_bias(new_num_position_embeddings)


  class CEHRGPT2Model(CEHRGPTPreTrainedModel):
@@ -609,7 +734,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  self.pretrained_wte = None

  self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
- self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+ if not self.exclude_position_ids:
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  if self.include_values:
  self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
  self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -635,6 +761,18 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  # Initialize weights and apply final processing
  self.post_init()

+ # We do need to update the pre-computed attention bias matrix if sample packing requires a larger context window
+ if self.config.sample_packing_max_positions > self.config.n_positions:
+ logger.info(
+ "Updated attn_bias to %s according to sample_packing_max_positions",
+ config.sample_packing_max_positions,
+ )
+ self.update_attn_bias(self.config.sample_packing_max_positions)
+
+ def enable_position_embeddings(self):
+ self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
+ self.config.exclude_position_ids = False
+

  def initialize_pretrained_embeddings(self):
  layers = [
@@ -677,7 +815,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  self.wte = self.wte.to(self.first_device)
  if self.config.use_pretrained_embeddings:
  self.pretrained_wte = self.pretrained_wte.to(self.first_device)
- self.wpe = self.wpe.to(self.first_device)
+ if not self.exclude_position_ids:
+ self.wpe = self.wpe.to(self.first_device)
  if self.include_values:
  self.vte = self.vte.to(self.first_device)
  self.concept_value_transformation_layer = (
@@ -703,7 +842,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  self.wte = self.wte.to("cpu")
  if self.config.use_pretrained_embeddings:
  self.pretrained_wte = self.pretrained_wte.to("cpu")
- self.wpe = self.wpe.to("cpu")
+ if not self.exclude_position_ids:
+ self.wpe = self.wpe.to("cpu")
  self.vte = self.vte.to("cpu")
  self.concept_value_transformation_layer = (
  self.concept_value_transformation_layer.to("cpu")
@@ -728,8 +868,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  persistent=False,
  )

- def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
- return self.wpe
+ def get_position_embeddings(
+ self,
+ ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
+ if not self.exclude_position_ids:
+ return self.wpe
+ return None

  def set_position_embeddings(self, new_embeddings: nn.Embedding):
  self.wpe = new_embeddings
@@ -758,8 +902,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  def forward(
  self,
  input_ids: Optional[torch.LongTensor],
- value_indicators: Optional[torch.BoolTensor],
- values: Optional[torch.LongTensor],
+ value_indicators: Optional[torch.BoolTensor] = None,
+ values: Optional[torch.LongTensor] = None,
  past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
@@ -850,12 +994,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  == "flash_attention_2"
  ):
  attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
+
+ # If this is sample packing, we need to create the block-diagonal attention mask
+ if is_sample_pack(attention_mask):
+ attention_mask = create_sample_packing_attention_mask(
+ attention_mask
+ )[:, None, :, :]
+ else:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]

  # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  # masked positions, this operation will create a tensor which is 0.0 for
@@ -925,7 +1076,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
  )

  if not self.exclude_position_ids:
- position_embeds = self.wpe(position_ids)
+ position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
  hidden_states = input_embeddings + position_embeds
  else:
  hidden_states = input_embeddings
@@ -1034,7 +1185,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  super().__init__(config)
  self.cehrgpt = CEHRGPT2Model(config)
  if self.config.include_ttv_prediction:
- self.tte_head = WeibullModel(config.n_embd)
+ self.tte_head = VisitTimeToEventHead(config.n_embd)

  if self.config.use_sub_time_tokenization:
  self.time_token_lm_head = nn.Linear(
@@ -1047,6 +1198,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  config.n_embd, config.value_vocab_size, bias=False
  )

+ if self.config.include_motor_time_to_event:
+ self.motor_tte = MotorTaskHead(
+ config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+ )
+
  # Model parallel
  self.model_parallel = False
  self.device_map = None
@@ -1074,6 +1230,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  self.value_head = self.value_head.to(self.cehrgpt.first_device)
  if self.config.include_ttv_prediction:
  self.tte_head = self.tte_head.to(self.cehrgpt.first_device)
+ if self.config.include_motor_time_to_event:
+ self.motor_tte = self.motor_tte.to(self.cehrgpt.first_device)
  self.model_parallel = True

  def deparallelize(self):
@@ -1088,6 +1246,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  self.value_head = self.value_head.to("cpu")
  if self.config.include_ttv_prediction:
  self.tte_head = self.tte_head.to("cpu")
+ if self.config.include_motor_time_to_event:
+ self.motor_tte = self.motor_tte.to("cpu")
  self.model_parallel = False
  torch.cuda.empty_cache()

@@ -1115,6 +1275,28 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  def update_attn_bias(self, max_position_embeddings: int):
  self.cehrgpt.update_attn_bias(max_position_embeddings)

+ def update_motor_tte_vocab_size(
+ self, motor_tte_vocab_size: Optional[int] = None
+ ) -> None:
+ update_motor_tte_layer = False
+ if motor_tte_vocab_size and motor_tte_vocab_size > 0:
+ if self.config.include_motor_time_to_event:
+ if self.config.motor_tte_vocab_size != motor_tte_vocab_size:
+ self.config.include_motor_time_to_event = True
+ self.config.motor_tte_vocab_size = motor_tte_vocab_size
+ update_motor_tte_layer = True
+ else:
+ self.config.include_motor_time_to_event = True
+ self.config.motor_tte_vocab_size = motor_tte_vocab_size
+ update_motor_tte_layer = True
+
+ if update_motor_tte_layer:
+ self.motor_tte = MotorTaskHead(
+ self.config.n_embd,
+ self.config.motor_tte_vocab_size,
+ self.config.motor_num_time_pieces,
+ )
+
  def prepare_inputs_for_generation(
  self,
  input_ids,
@@ -1210,6 +1392,74 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

  return model_inputs

+ def motor_nll_loss(
+ self,
+ ve_token_features,
+ motor_time_to_event_vectors,
+ motor_event_indicators,
+ motor_time_to_event_to_include,
+ motor_time_indicators,
+ batch_motor_end_index,
+ ):
+ """
+ Computes the negative log-likelihood (NLL) loss under a piecewise Exponential distribution
+ for modeling time-to-event data at each visit.
+
+ Args:
+ ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
+ motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+ motor_event_indicators (Tensor): Bool indicators (True if the event occurred, False if censored).
+ motor_time_to_event_to_include (Tensor): Bool indicators (True if included, False if not included).
+ motor_time_indicators (Tensor): Bool indicators of whether the time falls within the current
+ time bucket.
+ batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+
+ Returns:
+ Tensor: Scalar loss value (mean negative log-likelihood).
+ """
+ batch_motor_end_index = batch_motor_end_index.sum().item()
+ motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+ (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+ )[:batch_motor_end_index].clamp(min=1e-3)
+ motor_event_indicators = motor_event_indicators.reshape(
+ (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+ )[:batch_motor_end_index]
+ motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
+ :batch_motor_end_index
+ ]
+ motor_time_indicators = motor_time_indicators.view(
+ (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+ )[:batch_motor_end_index]
+ assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+ "The number of VE tokens in the labels needs to match up "
+ "with the first dimension of motor_time_to_event_vectors. "
+ f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
+ f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+ )
+ motor_time_to_event_vectors = motor_time_to_event_vectors[
+ motor_time_to_event_to_include
+ ]
+ motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
+ motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
+ ve_token_features = ve_token_features[motor_time_to_event_to_include]
+
+ # Get Exponential parameters from model
+ lambda_p = self.motor_tte(ve_token_features)
+ # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
+ dist = Exponential(lambda_p.clamp(min=1e-3))
+
+ # Compute event loss
+ tte_loss = torch.where(
+ motor_event_indicators,
+ -dist.log_prob(motor_time_to_event_vectors),
+ -torch.log(
+ 1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
+ ),
+ )
+ tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
+ return torch.mean(tte_loss)
+
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,
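For intuition, a self-contained sketch of the censored Exponential NLL used by motor_nll_loss: an observed event contributes -log f(t) = λt - log λ, while a censored observation contributes -log S(t) = λt. The rates, times, and indicators below are made up:

import torch
from torch.distributions import Exponential

rates = torch.tensor([0.5, 0.2])              # hypothetical hazard rates
times = torch.tensor([2.0, 3.0])              # hypothetical durations
event_observed = torch.tensor([True, False])

dist = Exponential(rates)
nll = torch.where(
    event_observed,
    -dist.log_prob(times),            # event:    rate * t - log(rate) ≈ 1.693
    -torch.log(1 - dist.cdf(times)),  # censored: rate * t = 0.6
)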
@@ -1226,6 +1476,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  time_to_visits: Optional[torch.FloatTensor] = None,
  time_token_indicators: Optional[torch.BoolTensor] = None,
  sub_time_tokens: Optional[torch.LongTensor] = None,
+ motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
+ motor_event_indicators: Optional[torch.BoolTensor] = None,
+ motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
+ motor_time_indicators: Optional[torch.BoolTensor] = None,
+ motor_end_index: Optional[torch.LongTensor] = None,
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
@@ -1285,12 +1540,31 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  time_token_loss = None
  time_to_visit_loss = None
  token_value_loss = None
+ motor_tte_loss = None
+
  if labels is not None:
  # move labels to correct device to enable model parallelism
  labels = labels.to(lm_logits.device)
+
+ if self.config.causal_sfm:
+ # Ensure demographic_labels matches the dtype of original labels
+ demographic_labels = torch.full(
+ (labels.shape[0], self.config.demographics_size),
+ -100,
+ dtype=labels.dtype, # Match the original labels' dtype
+ device=labels.device, # Ensure on the same device
+ )
+ # Concatenate the demographic labels with the rest of the original labels
+ labels = torch.cat(
+ (demographic_labels, labels[:, self.config.demographics_size :]),
+ dim=1,
+ )
+
  # Shift so that tokens < n predict n
  shift_logits = lm_logits[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
+ valid_tokens: torch.BoolTensor = shift_labels != -100
+ total_num_tokens = valid_tokens.sum()
  if (
  self.cehrgpt.config.lab_token_penalty
  and self.cehrgpt.config.lab_token_exists
@@ -1310,28 +1584,60 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  lab_index,
  token_loss * self.cehrgpt.config.lab_token_loss_weight,
  token_loss,
- ).mean()
+ )
+
+ token_loss = token_loss.sum() / total_num_tokens
  else:
  # Flatten the tokens
- loss_fct = CrossEntropyLoss()
+ loss_fct = CrossEntropyLoss(reduction="none")
  token_loss = loss_fct(
  shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
  )
- loss = token_loss
+ token_loss = token_loss.sum() / total_num_tokens
+
+ loss = token_loss * self.cehrgpt.config.next_token_prediction_loss_weight

  if self.cehrgpt.config.entropy_penalty:
  # Compute probabilities using softmax
- probs = torch.softmax(lm_logits, dim=-1)
+ probs = torch.softmax(shift_logits, dim=-1)
  # Compute negative entropy: sum(p * log(p))
  entropy = torch.sum(
  probs * torch.log(probs + 1e-9), dim=-1
  ) # Add epsilon for numerical stability
+ entropy = torch.where(valid_tokens, entropy, 0)
  # Regularization term: mean entropy scaled by alpha
- loss += self.cehrgpt.config.entropy_penalty_alpha * entropy.mean()
+ entropy_penalty = entropy.sum() / total_num_tokens
+ loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha
+
+ if (
+ self.config.include_motor_time_to_event
+ and motor_time_to_event_vectors is not None
+ and motor_event_indicators is not None
+ and motor_time_to_event_to_include is not None
+ and motor_time_indicators is not None
+ and motor_end_index is not None
+ ):
+ ve_token_id_indices = labels == self.config.ve_token_id
+ ve_token_features = hidden_states[ve_token_id_indices]
+ # Get rid of the last VE features because it's already reached the end of the patient sequence and
+ # there is nothing to predict.
+ motor_tte_loss = self.motor_nll_loss(
+ ve_token_features=ve_token_features,
+ motor_time_to_event_vectors=motor_time_to_event_vectors,
+ motor_event_indicators=motor_event_indicators,
+ motor_time_to_event_to_include=motor_time_to_event_to_include,
+ motor_time_indicators=motor_time_indicators,
+ batch_motor_end_index=motor_end_index,
+ )
+ loss += motor_tte_loss * self.config.motor_time_to_event_weight

  # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
  # predictions for year/month/token
- if self.config.use_sub_time_tokenization:
+ if (
+ self.config.use_sub_time_tokenization
+ and sub_time_tokens is not None
+ and time_token_indicators is not None
+ ):
  # Split the last dimensions into three parts
  time_loss_fct = CrossEntropyLoss(reduction="none")
  time_token_logits = self.time_token_lm_head(
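The pattern used throughout this loss block — per-token losses with reduction="none", zeroed where the label is the ignore index, then summed and divided by the number of valid tokens — is a masked mean. A small self-contained sketch with made-up shapes:

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 4, 7)                                  # (batch, seq, vocab)
labels = torch.tensor([[3, 5, -100, -100], [1, 2, 6, -100]])   # -100 marks ignored positions

loss_fct = CrossEntropyLoss(reduction="none")                  # ignored targets yield 0
per_token = loss_fct(logits.view(-1, 7), labels.view(-1))

valid = labels.view(-1) != -100
masked_mean = per_token.sum() / valid.sum()                    # averages over real tokens only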
@@ -1352,54 +1658,61 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  ),
  shifted_time_token_labels.view(-1),
  )
-
- time_token_loss = time_token_loss.view(
- -1, 3
- ) * shifted_time_token_indicators.view(-1, 1).to(hidden_states.dtype)
- time_token_loss = time_token_loss.sum(-1)
- time_token_loss = (
- torch.mean(time_token_loss) * self.config.time_token_loss_weight
+ time_token_loss = torch.where(
+ shifted_time_token_indicators.view(-1, 1).to(torch.bool),
+ time_token_loss.view(-1, 3),
+ 0,
  )
- loss += time_token_loss
-
- if time_to_visits is not None:
- # Get lambda and k parameters
- lambda_param, k_param = self.tte_head(hidden_states)
-
- # Perform slicing before tensors are split across GPUs
- shifted_lambda_param = lambda_param[..., :-1, :].contiguous()
- shifted_k_param = k_param[..., :-1, :].contiguous()
- shift_time_to_visits = time_to_visits[..., 1:].contiguous()
-
- # Move to the same device as lambda_param
- shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
-
- time_to_visit_indicator = (shift_time_to_visits >= 0).to(
- hidden_states.dtype
- )
- # Define the Gamma distribution
- dist = Gamma(shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1))
- # Compute log-probs and apply the time_to_visit_indicator
- log_probs = dist.log_prob(torch.clamp(shift_time_to_visits, min=0.0) + 1e-6)
- log_probs *= time_to_visit_indicator
- time_to_visit_loss = (
- -log_probs.mean() * self.config.time_to_visit_loss_weight
- )
- # Compute the loss
- loss += time_to_visit_loss
-
- if true_values is not None and true_value_indicators is not None:
- true_values = true_values.to(value_logits.device)
- shift_value_logits = value_logits[..., :-1, :].contiguous()
- shift_value_indicators = true_value_indicators[..., :-1].contiguous()
- shift_next_values = true_values[..., 1:].contiguous()
- value_loss_fct = CrossEntropyLoss(reduce=False)
- token_value_loss = value_loss_fct(
- shift_value_logits.view(-1, shift_value_logits.size(-1)),
- shift_next_values.view(-1),
- )
- token_value_loss *= shift_value_indicators.view(-1)
- loss += token_value_loss.mean()
+ time_token_loss = time_token_loss.sum() / total_num_tokens
+ loss += time_token_loss * self.config.time_token_loss_weight
+
+ if time_to_visits is not None:
+ # Get lambda and k parameters
+ lambda_param, k_param = self.tte_head(hidden_states)
+
+ # Perform slicing before tensors are split across GPUs
+ shifted_lambda_param = lambda_param[..., :-1, :].contiguous()
+ shifted_k_param = k_param[..., :-1, :].contiguous()
+ shift_time_to_visits = time_to_visits[..., 1:].contiguous()
+
+ # Move to the same device as lambda_param
+ shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
+ time_to_visit_indicator = shift_time_to_visits >= 0
+ # Define the Gamma distribution
+ dist = Gamma(
+ shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
+ )
+ # Compute log-probs and apply the time_to_visit_indicator
+ log_probs = dist.log_prob(
+ torch.clamp(shift_time_to_visits, min=1e-3) + 1e-6
+ )
+ log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
+ time_to_visit_loss = -log_probs.sum() / total_num_tokens
+ # Compute the loss
+ loss += time_to_visit_loss * self.config.time_to_visit_loss_weight
+
+ if true_values is not None and true_value_indicators is not None:
+ true_values = true_values.to(value_logits.device)
+ shift_value_logits = value_logits[..., :-1, :].contiguous()
+ shift_value_indicators = true_value_indicators[..., :-1].contiguous()
+ shift_next_values = true_values[..., 1:].contiguous()
+ value_loss_fct = CrossEntropyLoss(reduction="none")
+ token_value_loss = value_loss_fct(
+ shift_value_logits.view(-1, shift_value_logits.size(-1)),
+ shift_next_values.view(-1),
+ )
+ token_value_loss = torch.where(
+ shift_value_indicators.view(-1), token_value_loss, 0
+ )
+ token_value_loss = token_value_loss.sum() / total_num_tokens
+ if (
+ self.cehrgpt.config.lab_token_penalty
+ and self.cehrgpt.config.lab_token_exists
+ ):
+ token_value_loss = (
+ token_value_loss * self.config.lab_token_loss_weight
+ )
+ loss += token_value_loss * self.config.value_prediction_loss_weight

  if not return_dict:
  output = (lm_logits,) + transformer_outputs[1:]
@@ -1417,6 +1730,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  time_token_loss=time_token_loss,
  time_to_visit_loss=time_to_visit_loss,
  token_value_loss=token_value_loss,
+ motor_tte_loss=motor_tte_loss,
  )

  @staticmethod
@@ -1690,6 +2004,27 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
  )


+ class FocalLoss(nn.Module):
+ def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
+ super().__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.reduction = reduction
+
+ def forward(self, logits, targets):
+ bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+ probs = torch.sigmoid(logits)
+ pt = torch.where(targets == 1, probs, 1 - probs)
+ focal_term = (1 - pt) ** self.gamma
+ loss = self.alpha * focal_term * bce_loss
+
+ if self.reduction == "mean":
+ return loss.mean()
+ elif self.reduction == "sum":
+ return loss.sum()
+ return loss
+
+
  class CehrGptForClassification(CEHRGPTPreTrainedModel):
  _keep_in_fp32_modules = ["age_batch_norm", "dense_layer", "classifier"]

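A minimal usage sketch of the newly added FocalLoss (same import-path assumption as earlier; the logits and labels are made up). With gamma > 0, confidently classified examples are down-weighted, which helps with the heavy class imbalance typical of clinical outcome labels:

import torch
from cehrgpt.models.hf_cehrgpt import FocalLoss

logits = torch.tensor([2.5, -1.0, 0.3])   # raw classifier outputs
targets = torch.tensor([1.0, 0.0, 1.0])   # binary labels as floats

criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
loss = criterion(logits, targets)         # easy, correct examples contribute little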
@@ -1712,7 +2047,6 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
  self.model_parallel = False
  self.device_map = None
  self.gradient_checkpointing = False
-
  # Initialize weights and apply final processing
  self.post_init()

@@ -1768,6 +2102,7 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
  return_dict: Optional[bool] = None,
  **kwargs,
  ) -> CehrGptSequenceClassifierOutput:
+
  cehrgpt_output = self.cehrgpt(
  input_ids=input_ids,
  value_indicators=value_indicators,
@@ -1782,17 +2117,39 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
  return_dict=return_dict,
  )

- # Disable autocasting for precision-sensitive operations
- with torch.autocast(device_type="cuda", enabled=False):
- normalized_age = self._apply_age_norm(age_at_index)
+ if is_sample_pack(attention_mask):
+ features = extract_features_from_packed_sequence(
+ cehrgpt_output.last_hidden_state, attention_mask
+ )
+ assert features.shape[1] == classifier_label.shape[1], (
+ "the length of the features need to be the same as the length of classifier_label. "
+ f"features.shape[1]: {features.shape[1]}, "
+ f"classifier_label.shape[1]: {classifier_label.shape[1]}"
+ )
+ assert features.shape[1] == age_at_index.shape[1], (
+ "the length of the features need to be the same as the length of age_at_index. "
+ f"features.shape[1]: {features.shape[1]}, "
+ f"age_at_index.shape[1]: {age_at_index.shape[1]}"
+ )
+ num_samples = age_at_index.shape[1]
+ features = features.view((num_samples, -1))
+ classifier_label = classifier_label.view((num_samples, -1))
+ with torch.autocast(device_type="cuda", enabled=False):
+ normalized_age = self._apply_age_norm(
+ age_at_index.view((num_samples, 1))
+ )
+ else:
+ features = cehrgpt_output.last_hidden_state[..., -1, :]
+ # Disable autocasting for precision-sensitive operations
+ with torch.autocast(device_type="cuda", enabled=False):
+ normalized_age = self._apply_age_norm(age_at_index)

  # In case the model is in bfloat16
- if cehrgpt_output.last_hidden_state.dtype != normalized_age.dtype:
- normalized_age = normalized_age.to(cehrgpt_output.last_hidden_state.dtype)
+ if features.dtype != normalized_age.dtype:
+ normalized_age = normalized_age.to(features.dtype)

  # In fine-tuning, the sequences are left-padded, so we use the last element as the pooler
- output_pooler = cehrgpt_output.last_hidden_state[..., -1, :]
- next_input = self.dropout(output_pooler)
+ next_input = self.dropout(features)
  next_input = torch.cat([next_input, normalized_age], dim=1)
  next_input = self.dense_layer(next_input)
  next_input = nn.functional.relu(next_input)
@@ -1801,7 +2158,14 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):

  loss = None
  if classifier_label is not None:
- loss_fct = nn.BCEWithLogitsLoss()
+ if self.config.class_weights:
+ class_weights = torch.tensor(
+ [self.config.class_weights[1] / self.config.class_weights[0]],
+ dtype=torch.float32,
+ ).to(logits.device)
+ else:
+ class_weights = None
+ loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
  loss = loss_fct(logits, classifier_label)

  return CehrGptSequenceClassifierOutput(
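To illustrate what pos_weight does in the classifier loss above (the ratio below is made up, and the exact semantics of config.class_weights are defined by the package, not by this note): BCEWithLogitsLoss multiplies the positive-class term by pos_weight, so a value above 1 compensates for a rare positive class:

import torch
import torch.nn as nn

logits = torch.tensor([0.2, -1.3, 2.1])
labels = torch.tensor([1.0, 0.0, 1.0])

pos_weight = torch.tensor([9.0])  # e.g. roughly 9 negatives per positive (hypothetical)
weighted = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, labels)
plain = nn.BCEWithLogitsLoss()(logits, labels)
# weighted > plain: positive examples now count 9x in the average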