cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import Gamma
+from torch.distributions import Gamma, Weibull
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -45,12 +45,108 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)
 
 
+def extract_features_from_packed_sequence(
+    hidden_state: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> torch.Tensor:
+    max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
+    padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
+    feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
+    return hidden_state[:, feature_indices]
+
+
+def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Create a block-diagonal attention mask for packed sequences within a batch.
+
+    Args:
+        attention_mask (torch.Tensor): (batch_size, seq_len) binary mask where 1 = token, 0 = padding
+
+    Returns:
+        torch.Tensor: (batch_size, seq_len, seq_len) attention mask where entries are 1 if tokens
+            can attend to each other (within same packed segment), 0 otherwise.
+    """
+    # Step 1: Identify segments within each sample
+    cumsum_mask = (attention_mask == 0).cumsum(dim=-1)
+    segment_ids = cumsum_mask * attention_mask # zeros remain zero
+
+    # Step 2: Compare segment IDs pairwise per batch element
+    # Shape: (batch_size, seq_len, seq_len)
+    attn_matrix = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).int()
+
+    # Step 3: Mask out padding tokens
+    mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
+    attn_matrix = attn_matrix * mask
+
+    return attn_matrix
+
+
+def is_sample_pack(attention_mask: torch.Tensor) -> bool:
+    """
+    Determines whether any sequence in the batch is likely sample-packed.
+
+    A sample-packed sequence is one where there are non-padding (1) tokens
+    after a padding (0) token, indicating multiple sequences packed together
+    with padding as a separator.
+
+    Args:
+        attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
+            where 1 indicates a real token and 0 indicates padding.
+
+    Returns:
+        bool: True if any sample in the batch is sample-packed, False otherwise.
+    """
+
+    # If the attention_mask is left padded, we will flip it so we can use the same logic below
+    if (attention_mask[:, 0] == 0).any():
+        attention_mask = attention_mask.flip(dims=[1])
+
+    nonzero_counts = attention_mask.sum(dim=1)
+    max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
+    max_indices = attention_mask.shape[1] - 1 - max_token_positions
+    return torch.any(nonzero_counts < (max_indices + 1)).item()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # This infers sample packing
+    if is_sample_pack(attention_mask):
+        # Assume input: attention_mask shape = (batch, seq_len)
+        attention_mask = attention_mask.flatten() # shape: (seq_len,)
+
+        # Compute max_index of the last non-zero element
+        nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
+        max_index = nonzero[-1].item()
+
+        # Pad the truncated attention mask
+        padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
+
+        # Indices of all tokens
+        indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
+
+        # Find where 0s occur (segment boundaries)
+        cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
+            padded_attention_mask == 0
+        ]
+
+        # Compute seqlens per segment
+        seqlens_in_batch = (
+            cumsum_seqlens_in_batch
+            - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
+        ).to(torch.int)
+
+        max_seqlen_in_batch = (
+            seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
+        )
+        cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
+    else:
+        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+
     return (
         indices,
         cu_seqlens,
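To illustrate the sample-packing helpers introduced above, here is a minimal standalone sketch (illustration only, not part of the diff) that exercises is_sample_pack and create_sample_packing_attention_mask on a toy mask; the import path assumes the helpers are exposed at module level in cehrgpt/models/hf_cehrgpt.py as shown in the hunk:

import torch
from cehrgpt.models.hf_cehrgpt import (
    create_sample_packing_attention_mask,
    is_sample_pack,
)

# Two packed segments of lengths 3 and 2, separated and terminated by padding zeros.
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])

assert is_sample_pack(attention_mask)  # real tokens appear after a padding token

# Block-diagonal (1, 8, 8) mask: positions 0-2 attend only to each other,
# positions 4-5 likewise, and padding positions attend to nothing.
block_mask = create_sample_packing_attention_mask(attention_mask)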
@@ -609,7 +705,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             self.pretrained_wte = None
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
-        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        if not self.exclude_position_ids:
+            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         if self.include_values:
             self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
             self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -635,6 +732,14 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # We do need to update the pre-computed attention bias matrix if sample packing requires a larger context window
+        if self.config.sample_packing_max_positions > self.config.n_positions:
+            logger.info(
+                "Updated attn_bias to %s according to sample_packing_max_positions",
+                config.sample_packing_max_positions,
+            )
+            self.update_attn_bias(self.config.sample_packing_max_positions)
+
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -677,7 +782,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to(self.first_device)
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to(self.first_device)
-        self.wpe = self.wpe.to(self.first_device)
+        if not self.exclude_position_ids:
+            self.wpe = self.wpe.to(self.first_device)
         if self.include_values:
             self.vte = self.vte.to(self.first_device)
             self.concept_value_transformation_layer = (
@@ -703,7 +809,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to("cpu")
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to("cpu")
-        self.wpe = self.wpe.to("cpu")
+        if not self.exclude_position_ids:
+            self.wpe = self.wpe.to("cpu")
         self.vte = self.vte.to("cpu")
         self.concept_value_transformation_layer = (
             self.concept_value_transformation_layer.to("cpu")
@@ -728,8 +835,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             persistent=False,
         )
 
-    def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
-        return self.wpe
+    def get_position_embeddings(
+        self,
+    ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
+        if not self.exclude_position_ids:
+            return self.wpe
+        return None
 
     def set_position_embeddings(self, new_embeddings: nn.Embedding):
         self.wpe = new_embeddings
@@ -758,8 +869,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor],
-        value_indicators: Optional[torch.BoolTensor],
-        values: Optional[torch.LongTensor],
+        value_indicators: Optional[torch.BoolTensor] = None,
+        values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -850,12 +961,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 == "flash_attention_2"
             ):
                 attention_mask = attention_mask.view(batch_size, -1)
-                # We create a 3D attention mask from a 2D tensor mask.
-                # Sizes are [batch_size, 1, 1, to_seq_length]
-                # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-                # this attention mask is more simple than the triangular masking of causal attention
-                # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-                attention_mask = attention_mask[:, None, None, :]
+
+                # If this is sample packing, we need to create the block-diagonal attention mask
+                if is_sample_pack(attention_mask):
+                    attention_mask = create_sample_packing_attention_mask(
+                        attention_mask
+                    )[:, None, :, :]
+                else:
+                    # We create a 3D attention mask from a 2D tensor mask.
+                    # Sizes are [batch_size, 1, 1, to_seq_length]
+                    # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+                    # this attention mask is more simple than the triangular masking of causal attention
+                    # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+                    attention_mask = attention_mask[:, None, None, :]
 
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -1288,9 +1406,26 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
+
+            if self.config.causal_sfm:
+                # Ensure demographic_labels matches the dtype of original labels
+                demographic_labels = torch.full(
+                    (labels.shape[0], self.config.demographics_size),
+                    -100,
+                    dtype=labels.dtype, # Match the original labels' dtype
+                    device=labels.device, # Ensure on the same device
+                )
+                # Concatenate the demographic labels with the rest of the original labels
+                labels = torch.cat(
+                    (demographic_labels, labels[:, self.config.demographics_size :]),
+                    dim=1,
+                )
+
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            valid_tokens: torch.BoolTensor = shift_labels != 100
+            total_num_tokens = valid_tokens.sum()
             if (
                 self.cehrgpt.config.lab_token_penalty
                 and self.cehrgpt.config.lab_token_exists
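In the causal_sfm branch above, the demographic prefix is excluded from the next-token objective by overwriting those label positions with -100, the index that CrossEntropyLoss ignores by default. A minimal sketch of the idea with toy shapes (independent of the model code):

import torch
from torch.nn import CrossEntropyLoss

demographics_size = 4
labels = torch.randint(0, 100, (2, 16))  # toy label ids, vocab size 100

# Replace the demographic prefix with -100 so it contributes nothing to the loss.
demographic_labels = torch.full((labels.shape[0], demographics_size), -100)
labels = torch.cat((demographic_labels, labels[:, demographics_size:]), dim=1)

logits = torch.randn(2, 16, 100)
loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))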
@@ -1310,24 +1445,30 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                     lab_index,
                     token_loss * self.cehrgpt.config.lab_token_loss_weight,
                     token_loss,
-                ).mean()
+                )
+
+                token_loss = token_loss.sum() / total_num_tokens
             else:
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
+                loss_fct = CrossEntropyLoss(reduction="none")
                 token_loss = loss_fct(
                     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                 )
-                loss = token_loss
+                token_loss = token_loss.sum() / total_num_tokens
+
+            loss = token_loss * self.cehrgpt.config.next_token_prediction_loss_weight
 
             if self.cehrgpt.config.entropy_penalty:
                 # Compute probabilities using softmax
-                probs = torch.softmax(lm_logits, dim=-1)
+                probs = torch.softmax(shift_logits, dim=-1)
                 # Compute negative entropy: sum(p * log(p))
                 entropy = torch.sum(
                     probs * torch.log(probs + 1e-9), dim=-1
                 ) # Add epsilon for numerical stability
+                entropy = torch.where(valid_tokens, entropy, 0)
                 # Regularization term: mean entropy scaled by alpha
-                loss += self.cehrgpt.config.entropy_penalty_alpha * entropy.mean()
+                entropy_penalty = entropy.sum() / total_num_tokens
+                loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha
 
             # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
             # predictions for year/month/token
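The edits above replace plain .mean() reductions with a sum over valid token positions divided by total_num_tokens, so padding inside packed batches no longer dilutes the loss, and each term is then scaled by its configured weight. A minimal standalone sketch of that normalization with toy tensors (the -100 ignore index follows the usual PyTorch convention):

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 8, 50)   # (batch, seq_len, vocab)
labels = torch.randint(0, 50, (2, 8))
labels[:, -3:] = -100            # padded / ignored positions

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
valid_tokens = shift_labels != -100
total_num_tokens = valid_tokens.sum()

loss_fct = CrossEntropyLoss(reduction="none")  # ignored positions contribute 0
token_loss = loss_fct(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
# Average over real tokens only, rather than over every position in the batch.
loss = token_loss.sum() / total_num_tokens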
@@ -1352,54 +1493,60 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                     ),
                     shifted_time_token_labels.view(-1),
                 )
-
-                time_token_loss = time_token_loss.view(
-                    -1, 3
-                ) * shifted_time_token_indicators.view(-1, 1).to(hidden_states.dtype)
-                time_token_loss = time_token_loss.sum(-1)
-                time_token_loss = (
-                    torch.mean(time_token_loss) * self.config.time_token_loss_weight
+                time_token_loss = torch.where(
+                    shifted_time_token_indicators.view(-1, 1).to(torch.bool),
+                    time_token_loss.view(-1, 3),
+                    0,
                 )
-                loss += time_token_loss
+                time_token_loss = time_token_loss.sum() / total_num_tokens
+                loss += time_token_loss * self.config.time_token_loss_weight
 
-            if time_to_visits is not None:
-                # Get lambda and k parameters
-                lambda_param, k_param = self.tte_head(hidden_states)
+            if time_to_visits is not None:
+                # Get lambda and k parameters
+                lambda_param, k_param = self.tte_head(hidden_states)
 
-                # Perform slicing before tensors are split across GPUs
-                shifted_lambda_param = lambda_param[..., :-1, :].contiguous()
-                shifted_k_param = k_param[..., :-1, :].contiguous()
-                shift_time_to_visits = time_to_visits[..., 1:].contiguous()
+                # Perform slicing before tensors are split across GPUs
+                shifted_lambda_param = lambda_param[..., :-1, :].contiguous()
+                shifted_k_param = k_param[..., :-1, :].contiguous()
+                shift_time_to_visits = time_to_visits[..., 1:].contiguous()
 
-                # Move to the same device as lambda_param
-                shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
+                # Move to the same device as lambda_param
+                shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
 
-                time_to_visit_indicator = (shift_time_to_visits >= 0).to(
-                    hidden_states.dtype
-                )
-                # Define the Gamma distribution
-                dist = Gamma(shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1))
-                # Compute log-probs and apply the time_to_visit_indicator
-                log_probs = dist.log_prob(torch.clamp(shift_time_to_visits, min=0.0) + 1e-6)
-                log_probs *= time_to_visit_indicator
-                time_to_visit_loss = (
-                    -log_probs.mean() * self.config.time_to_visit_loss_weight
-                )
-                # Compute the loss
-                loss += time_to_visit_loss
-
-            if true_values is not None and true_value_indicators is not None:
-                true_values = true_values.to(value_logits.device)
-                shift_value_logits = value_logits[..., :-1, :].contiguous()
-                shift_value_indicators = true_value_indicators[..., :-1].contiguous()
-                shift_next_values = true_values[..., 1:].contiguous()
-                value_loss_fct = CrossEntropyLoss(reduce=False)
-                token_value_loss = value_loss_fct(
-                    shift_value_logits.view(-1, shift_value_logits.size(-1)),
-                    shift_next_values.view(-1),
-                )
-                token_value_loss *= shift_value_indicators.view(-1)
-                loss += token_value_loss.mean()
+                time_to_visit_indicator = shift_time_to_visits >= 0
+                # Define the Gamma distribution
+                dist = Gamma(
+                    shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
+                )
+                # Compute log-probs and apply the time_to_visit_indicator
+                log_probs = dist.log_prob(torch.clamp(shift_time_to_visits, min=1e-3))
+                log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
+                time_to_visit_loss = -log_probs.sum() / total_num_tokens
+                # Compute the loss
+                loss += time_to_visit_loss * self.config.time_to_visit_loss_weight
+
+            if true_values is not None and true_value_indicators is not None:
+                true_values = true_values.to(value_logits.device)
+                shift_value_logits = value_logits[..., :-1, :].contiguous()
+                shift_value_indicators = true_value_indicators[..., :-1].contiguous()
+                shift_next_values = true_values[..., 1:].contiguous()
+                value_loss_fct = CrossEntropyLoss(reduction="none")
+                token_value_loss = value_loss_fct(
+                    shift_value_logits.view(-1, shift_value_logits.size(-1)),
+                    shift_next_values.view(-1),
+                )
+                token_value_loss = torch.where(
+                    shift_value_indicators.view(-1), token_value_loss, 0
+                )
+                token_value_loss = token_value_loss.sum() / total_num_tokens
+                if (
+                    self.cehrgpt.config.lab_token_penalty
+                    and self.cehrgpt.config.lab_token_exists
+                ):
+                    token_value_loss = (
+                        token_value_loss * self.config.lab_token_loss_weight
+                    )
+                loss += token_value_loss * self.config.value_prediction_loss_weight
 
             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
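The time-to-visit term fits a Gamma head and adds the negative log-likelihood of the observed inter-visit intervals, again normalized by a token count rather than a plain mean. A minimal standalone sketch of that likelihood with toy tensors (torch.distributions.Gamma takes concentration then rate, matching the k/lambda order used above; here the sum is normalized by the number of positions carrying a real target):

import torch
from torch.distributions import Gamma

k_param = torch.rand(2, 7) + 0.5       # concentration per position
lambda_param = torch.rand(2, 7) + 0.5  # rate per position
time_to_visits = torch.tensor([[3.0, -100.0, 10.0, 0.5, -100.0, 2.0, 1.0],
                               [7.0, 1.0, -100.0, 4.0, 2.0, -100.0, 0.1]])

indicator = time_to_visits >= 0  # positions with an observed interval
dist = Gamma(k_param, lambda_param)
log_probs = dist.log_prob(torch.clamp(time_to_visits, min=1e-3))
log_probs = torch.where(indicator, log_probs, torch.zeros_like(log_probs))
time_to_visit_loss = -log_probs.sum() / indicator.sum()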
@@ -1768,6 +1915,7 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> CehrGptSequenceClassifierOutput:
+
         cehrgpt_output = self.cehrgpt(
             input_ids=input_ids,
             value_indicators=value_indicators,
@@ -1782,17 +1930,39 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
             return_dict=return_dict,
         )
 
-        # Disable autocasting for precision-sensitive operations
-        with torch.autocast(device_type="cuda", enabled=False):
-            normalized_age = self._apply_age_norm(age_at_index)
+        if is_sample_pack(attention_mask):
+            features = extract_features_from_packed_sequence(
+                cehrgpt_output.last_hidden_state, attention_mask
+            )
+            assert features.shape[1] == classifier_label.shape[1], (
+                "the length of the features need to be the same as the length of classifier_label. "
+                f"features.shape[1]: {features.shape[1]}, "
+                f"classifier_label.shape[1]: {classifier_label.shape[1]}"
+            )
+            assert features.shape[1] == age_at_index.shape[1], (
+                "the length of the features need to be the same as the length of age_at_index. "
+                f"features.shape[1]: {features.shape[1]}, "
+                f"age_at_index.shape[1]: {age_at_index.shape[1]}"
+            )
+            num_samples = age_at_index.shape[1]
+            features = features.view((num_samples, -1))
+            classifier_label = classifier_label.view((num_samples, -1))
+            with torch.autocast(device_type="cuda", enabled=False):
+                normalized_age = self._apply_age_norm(
+                    age_at_index.view((num_samples, 1))
+                )
+        else:
+            features = cehrgpt_output.last_hidden_state[..., -1, :]
+            # Disable autocasting for precision-sensitive operations
+            with torch.autocast(device_type="cuda", enabled=False):
+                normalized_age = self._apply_age_norm(age_at_index)
 
         # In case the model is in bfloat16
-        if cehrgpt_output.last_hidden_state.dtype != normalized_age.dtype:
-            normalized_age = normalized_age.to(cehrgpt_output.last_hidden_state.dtype)
+        if features.dtype != normalized_age.dtype:
+            normalized_age = normalized_age.to(features.dtype)
 
         # In fine-tuning, the sequences are left-padded, so we use the last element as the pooler
-        output_pooler = cehrgpt_output.last_hidden_state[..., -1, :]
-        next_input = self.dropout(output_pooler)
+        next_input = self.dropout(features)
         next_input = torch.cat([next_input, normalized_age], dim=1)
         next_input = self.dense_layer(next_input)
         next_input = nn.functional.relu(next_input)
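With sample packing, the classifier head pulls one feature vector per packed sample (the hidden state of the last real token before each padding boundary) instead of the single last position of the row. A minimal sketch of that selection, mirroring the logic of extract_features_from_packed_sequence with toy tensors:

import torch
import torch.nn.functional as F

hidden_state = torch.randn(1, 8, 4)                         # (1, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])   # two packed samples

# Trim trailing padding, then append a zero so the last sample also ends on a boundary.
max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
padded_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))

# The position right before each zero is the last real token of a packed sample.
feature_indices = torch.nonzero(padded_mask == 0)[:, 1] - 1
features = hidden_state[:, feature_indices]                 # shape: (1, 2, 4)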
@@ -25,6 +25,7 @@ from tokenizers.pre_tokenizers import WhitespaceSplit
 from tokenizers.trainers import WordLevelTrainer
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer
+from transformers.utils import logging
 
 from cehrgpt.gpt_utils import (
     convert_time_interval_to_time_tuple,
@@ -53,6 +54,7 @@ TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.jso
 LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
 LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
 CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+LOG = logging.get_logger("transformers")
 
 
 def truncated_sample(sample, standard_deviation):
@@ -888,6 +890,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         if isinstance(dataset, DatasetDict):
             dataset = dataset["train"]
 
+        LOG.info("Training the tokenizer for concepts")
         concept_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name="concept_ids",
@@ -900,6 +903,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             if concept_value_column not in row:
                 concept_value_column = "concept_values"
             break
+        LOG.info("Training the tokenizer for values")
         value_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name=concept_value_column,