cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import Gamma
+from torch.distributions import Gamma, Weibull
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -45,12 +45,108 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)


+def extract_features_from_packed_sequence(
+    hidden_state: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> torch.Tensor:
+    max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
+    padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
+    feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
+    return hidden_state[:, feature_indices]
+
+
+def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Create a block-diagonal attention mask for packed sequences within a batch.
+
+    Args:
+        attention_mask (torch.Tensor): (batch_size, seq_len) binary mask where 1 = token, 0 = padding
+
+    Returns:
+        torch.Tensor: (batch_size, seq_len, seq_len) attention mask where entries are 1 if tokens
+            can attend to each other (within same packed segment), 0 otherwise.
+    """
+    # Step 1: Identify segments within each sample
+    cumsum_mask = (attention_mask == 0).cumsum(dim=-1)
+    segment_ids = cumsum_mask * attention_mask  # zeros remain zero
+
+    # Step 2: Compare segment IDs pairwise per batch element
+    # Shape: (batch_size, seq_len, seq_len)
+    attn_matrix = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).int()
+
+    # Step 3: Mask out padding tokens
+    mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
+    attn_matrix = attn_matrix * mask
+
+    return attn_matrix
+
+
+def is_sample_pack(attention_mask: torch.Tensor) -> bool:
+    """
+    Determines whether any sequence in the batch is likely sample-packed.
+
+    A sample-packed sequence is one where there are non-padding (1) tokens
+    after a padding (0) token, indicating multiple sequences packed together
+    with padding as a separator.
+
+    Args:
+        attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
+            where 1 indicates a real token and 0 indicates padding.
+
+    Returns:
+        bool: True if any sample in the batch is sample-packed, False otherwise.
+    """
+
+    # If the attention_maks is left padded, we will flip it so we can use the same logic below
+    if (attention_mask[:, 0] == 0).any():
+        attention_mask = attention_mask.flip(dims=[1])
+
+    nonzero_counts = attention_mask.sum(dim=1)
+    max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
+    max_indices = attention_mask.shape[1] - 1 - max_token_positions
+    return torch.any(nonzero_counts < (max_indices + 1)).item()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # This infers sample packing
+    if is_sample_pack(attention_mask):
+        # Assume input: attention_mask shape = (batch, seq_len)
+        attention_mask = attention_mask.flatten()  # shape: (seq_len,)
+
+        # Compute max_index of the last non-zero element
+        nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
+        max_index = nonzero[-1].item()
+
+        # Pad the truncated attention mask
+        padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
+
+        # Indices of all tokens
+        indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
+
+        # Find where 0s occur (segment boundaries)
+        cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
+            padded_attention_mask == 0
+        ]
+
+        # Compute seqlens per segment
+        seqlens_in_batch = (
+            cumsum_seqlens_in_batch
+            - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
+        ).to(torch.int)
+
+        max_seqlen_in_batch = (
+            seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
+        )
+        cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
+    else:
+        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+
     return (
         indices,
         cu_seqlens,
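To make the new sample-packing helpers concrete, the sketch below runs the same logic on a toy right-padded batch in which one row packs two sequences separated by a single padding token. The function bodies are condensed from the two helpers added above; the toy tensors are illustrative only.

```python
# Toy demonstration of the sample-packing helpers added in this release.
# The logic is condensed from the diff above; the batch below is made up.
import torch


def is_sample_pack(attention_mask: torch.Tensor) -> bool:
    # Flip left-padded masks so the check works on a right-padded layout.
    if (attention_mask[:, 0] == 0).any():
        attention_mask = attention_mask.flip(dims=[1])
    nonzero_counts = attention_mask.sum(dim=1)
    max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
    max_indices = attention_mask.shape[1] - 1 - max_token_positions
    # Packed if any row still has real tokens after an interior padding gap.
    return torch.any(nonzero_counts < (max_indices + 1)).item()


def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Give every packed segment its own id (padding stays 0), then allow
    # attention only between positions that share a segment id.
    segment_ids = (attention_mask == 0).cumsum(dim=-1) * attention_mask
    attn_matrix = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).int()
    pad_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    return attn_matrix * pad_mask


# Two sequences packed into one row, separated by one padding token.
packed = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])
plain = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

print(is_sample_pack(packed))  # True: real tokens follow an interior 0
print(is_sample_pack(plain))   # False: only trailing padding
print(create_sample_packing_attention_mask(packed)[0])
```

The printed matrix is block-diagonal: positions 0-2 attend only to each other, positions 4-5 only to each other, and padding positions attend to nothing.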
@@ -609,7 +705,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.pretrained_wte = None

         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
-        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        if not self.exclude_position_ids:
+            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         if self.include_values:
             self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
             self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -635,6 +732,14 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+        # We do need to update the pre-computed attention bias matrix if sample packing requires a larger context window
+        if self.config.sample_packing_max_positions > self.config.n_positions:
+            logger.info(
+                "Updated attn_bias to %s according to sample_packing_max_positions",
+                config.sample_packing_max_positions,
+            )
+            self.update_attn_bias(self.config.sample_packing_max_positions)
+
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
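The constructor now grows the pre-computed attention bias whenever sample_packing_max_positions exceeds n_positions. The body of update_attn_bias is not part of this hunk, so the toy module below is only a hedged sketch of the general pattern in GPT-2-style models (re-registering a larger lower-triangular buffer); every name in it is illustrative and none of it is the package's actual implementation.

```python
# Hedged sketch only: what resizing a pre-computed causal bias buffer can look
# like in a GPT-2-style attention module. Not taken from cehrgpt's source.
import torch
from torch import nn


class ToyCausalSelfAttention(nn.Module):
    def __init__(self, max_positions: int = 1024):
        super().__init__()
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(max_positions, max_positions, dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )

    def update_attn_bias(self, new_max_positions: int) -> None:
        # Re-register the triangular mask so packed sequences longer than the
        # original context length still get a valid causal bias.
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(new_max_positions, new_max_positions, dtype=torch.bool)
            ).view(1, 1, new_max_positions, new_max_positions),
            persistent=False,
        )


attn = ToyCausalSelfAttention(max_positions=4)
attn.update_attn_bias(8)
print(attn.bias.shape)  # torch.Size([1, 1, 8, 8])
```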
@@ -677,7 +782,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to(self.first_device)
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to(self.first_device)
-        self.wpe = self.wpe.to(self.first_device)
+        if not self.exclude_position_ids:
+            self.wpe = self.wpe.to(self.first_device)
         if self.include_values:
             self.vte = self.vte.to(self.first_device)
             self.concept_value_transformation_layer = (
@@ -703,7 +809,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to("cpu")
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to("cpu")
-        self.wpe = self.wpe.to("cpu")
+        if not self.exclude_position_ids:
+            self.wpe = self.wpe.to("cpu")
         self.vte = self.vte.to("cpu")
         self.concept_value_transformation_layer = (
             self.concept_value_transformation_layer.to("cpu")
@@ -728,8 +835,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             persistent=False,
         )

-    def get_position_embeddings(
-
+    def get_position_embeddings(
+        self,
+    ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
+        if not self.exclude_position_ids:
+            return self.wpe
+        return None

     def set_position_embeddings(self, new_embeddings: nn.Embedding):
         self.wpe = new_embeddings
@@ -758,8 +869,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor],
-        value_indicators: Optional[torch.BoolTensor],
-        values: Optional[torch.LongTensor],
+        value_indicators: Optional[torch.BoolTensor] = None,
+        values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -850,12 +961,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 == "flash_attention_2"
             ):
                 attention_mask = attention_mask.view(batch_size, -1)
-                # We create a 3D attention mask from a 2D tensor mask.
-                # Sizes are [batch_size, 1, 1, to_seq_length]
-                # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-                # this attention mask is more simple than the triangular masking of causal attention
-                # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-                attention_mask = attention_mask[:, None, None, :]
+
+                # If this is sample packing, we need to great the
+                if is_sample_pack(attention_mask):
+                    attention_mask = create_sample_packing_attention_mask(
+                        attention_mask
+                    )[:, None, :, :]
+                else:
+                    # We create a 3D attention mask from a 2D tensor mask.
+                    # Sizes are [batch_size, 1, 1, to_seq_length]
+                    # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+                    # this attention mask is more simple than the triangular masking of causal attention
+                    # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+                    attention_mask = attention_mask[:, None, None, :]

             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
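A shape sketch of the two mask layouts this branch produces: packed batches get a per-pair (batch, 1, seq, seq) block-diagonal mask, while ordinary batches keep the broadcastable (batch, 1, 1, seq) form. The final step below converts the 0/1 mask into an additive bias in the GPT-2 style described by the unchanged comments after the hunk; expressing it as `(1.0 - mask) * dtype_min` is an assumption here, since those lines are not part of this diff.

```python
# Shape sketch for the packed vs. un-packed attention mask paths (toy data).
import torch

mask = torch.tensor([[1, 1, 0, 1, 1, 0]])  # one packed row, two segments

# (B, 1, S, S): block-diagonal mask, same construction as the helper above
segment_ids = (mask == 0).cumsum(dim=-1) * mask
packed_4d = (
    (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).int()
    * (mask.unsqueeze(1) * mask.unsqueeze(2))
)[:, None, :, :]

# (B, 1, 1, S): plain broadcast mask for an ordinary right-padded batch
plain_4d = mask[:, None, None, :]

print(packed_4d.shape, plain_4d.shape)  # (1, 1, 6, 6) and (1, 1, 1, 6)

# Assumed additive conversion: 0 where attention is allowed, a very large
# negative value where it is masked out.
bias = (1.0 - packed_4d.float()) * torch.finfo(torch.float32).min
print(bias[0, 0])
```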
@@ -1288,9 +1406,26 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
+
+            if self.config.causal_sfm:
+                # Ensure demographic_labels matches the dtype of original labels
+                demographic_labels = torch.full(
+                    (labels.shape[0], self.config.demographics_size),
+                    -100,
+                    dtype=labels.dtype,  # Match the original labels' dtype
+                    device=labels.device,  # Ensure on the same device
+                )
+                # Concatenate the demographic labels with the rest of the original labels
+                labels = torch.cat(
+                    (demographic_labels, labels[:, self.config.demographics_size :]),
+                    dim=1,
+                )
+
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            valid_tokens: torch.BoolTensor = shift_labels != 100
+            total_num_tokens = valid_tokens.sum()
             if (
                 self.cehrgpt.config.lab_token_penalty
                 and self.cehrgpt.config.lab_token_exists
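A toy illustration of the causal_sfm branch above: the first demographics_size label positions are overwritten with the ignore index -100 before the usual shift, so the demographic prompt tokens drop out of the next-token loss. The sizes and values below are made up, and the validity check uses the conventional -100 ignore index.

```python
# Toy sketch of demographic-label masking (illustrative sizes only).
import torch

demographics_size = 3  # stand-in for self.config.demographics_size
labels = torch.tensor([[10, 11, 12, 55, 56, 57, 58]])

demographic_labels = torch.full(
    (labels.shape[0], demographics_size),
    -100,
    dtype=labels.dtype,
    device=labels.device,
)
labels = torch.cat((demographic_labels, labels[:, demographics_size:]), dim=1)
print(labels)  # tensor([[-100, -100, -100, 55, 56, 57, 58]])

# After the shift, ignored positions are excluded from the token count used
# to normalize the per-token losses.
shift_labels = labels[..., 1:].contiguous()
valid_tokens = shift_labels != -100
print(valid_tokens.sum())  # tensor(4)
```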
@@ -1310,24 +1445,30 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                     lab_index,
                     token_loss * self.cehrgpt.config.lab_token_loss_weight,
                     token_loss,
-                )
+                )
+
+                token_loss = token_loss.sum() / total_num_tokens
             else:
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
+                loss_fct = CrossEntropyLoss(reduction="none")
                 token_loss = loss_fct(
                     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                 )
-
+                token_loss = token_loss.sum() / total_num_tokens
+
+            loss = token_loss * self.cehrgpt.config.next_token_prediction_loss_weight

             if self.cehrgpt.config.entropy_penalty:
                 # Compute probabilities using softmax
-                probs = torch.softmax(
+                probs = torch.softmax(shift_logits, dim=-1)
                 # Compute negative entropy: sum(p * log(p))
                 entropy = torch.sum(
                     probs * torch.log(probs + 1e-9), dim=-1
                 ) # Add epsilon for numerical stability
+                entropy = torch.where(valid_tokens, entropy, 0)
                 # Regularization term: mean entropy scaled by alpha
-
+                entropy_penalty = entropy.sum() / total_num_tokens
+                loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha

             # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
             # predictions for year/month/token
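The switch to CrossEntropyLoss(reduction="none") lets every term be summed and divided by one shared total_num_tokens, which matters once sample packing puts a variable number of real tokens in each row. A minimal standalone sketch of that normalization follows; the tensors are toys and the model's extra loss weights (lab token penalty, next-token weight) are omitted.

```python
# Per-token loss normalization over valid (non-ignored) tokens only.
import torch
from torch.nn import CrossEntropyLoss

vocab = 5
shift_logits = torch.randn(2, 4, vocab)            # (batch, seq-1, vocab)
shift_labels = torch.tensor([[1, 2, -100, -100],
                             [3, 4, 0, -100]])     # -100 marks ignored slots

valid_tokens = shift_labels != -100
total_num_tokens = valid_tokens.sum()

loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
token_loss = loss_fct(shift_logits.view(-1, vocab), shift_labels.view(-1))
# With reduction="none", ignored positions contribute 0, so the sum counts
# only real tokens; dividing by total_num_tokens gives a mean over them.
loss = token_loss.sum() / total_num_tokens
print(loss)
```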
@@ -1352,54 +1493,60 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                     ),
                     shifted_time_token_labels.view(-1),
                 )
-
-
-                    -1, 3
-
-                time_token_loss = time_token_loss.sum(-1)
-                time_token_loss = (
-                    torch.mean(time_token_loss) * self.config.time_token_loss_weight
+                time_token_loss = torch.where(
+                    shifted_time_token_indicators.view(-1, 1).to(torch.bool),
+                    time_token_loss.view(-1, 3),
+                    0,
                 )
-
+                time_token_loss = time_token_loss.sum() / total_num_tokens
+                loss += time_token_loss * self.config.time_token_loss_weight

-
-
-
+            if time_to_visits is not None:
+                # Get lambda and k parameters
+                lambda_param, k_param = self.tte_head(hidden_states)

-
-
-
-
+                # Perform slicing before tensors are split across GPUs
+                shifted_lambda_param = lambda_param[..., :-1, :].contiguous()
+                shifted_k_param = k_param[..., :-1, :].contiguous()
+                shift_time_to_visits = time_to_visits[..., 1:].contiguous()

-
-
+                # Move to the same device as lambda_param
+                shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)

- [... 26 removed lines not shown ...]
+                time_to_visit_indicator = shift_time_to_visits >= 0
+                # Define the Gamma distribution
+                dist = Gamma(
+                    shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
+                )
+                # Compute log-probs and apply the time_to_visit_indicator
+                log_probs = dist.log_prob(torch.clamp(shift_time_to_visits, min=1e-3))
+                log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
+                time_to_visit_loss = -log_probs.sum() / total_num_tokens
+                # Compute the loss
+                loss += time_to_visit_loss * self.config.time_to_visit_loss_weight
+
+            if true_values is not None and true_value_indicators is not None:
+                true_values = true_values.to(value_logits.device)
+                shift_value_logits = value_logits[..., :-1, :].contiguous()
+                shift_value_indicators = true_value_indicators[..., :-1].contiguous()
+                shift_next_values = true_values[..., 1:].contiguous()
+                value_loss_fct = CrossEntropyLoss(reduction="none")
+                token_value_loss = value_loss_fct(
+                    shift_value_logits.view(-1, shift_value_logits.size(-1)),
+                    shift_next_values.view(-1),
+                )
+                token_value_loss = torch.where(
+                    shift_value_indicators.view(-1), token_value_loss, 0
+                )
+                token_value_loss = token_value_loss.sum() / total_num_tokens
+                if (
+                    self.cehrgpt.config.lab_token_penalty
+                    and self.cehrgpt.config.lab_token_exists
+                ):
+                    token_value_loss = (
+                        token_value_loss * self.config.lab_token_loss_weight
+                    )
+                loss += token_value_loss * self.config.value_prediction_loss_weight

             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
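The time-to-visit term fits a Gamma head: tte_head emits a concentration (k) and a rate (lambda) per position, and the loss is the negative log-likelihood of the observed visit gaps, zeroed wherever no gap is defined (indicator < 0) and normalized by the shared token count. A standalone sketch follows; the k/lambda tensors are random stand-ins for the head outputs, and the token count here is just the indicator sum rather than the model's label-based count.

```python
# Standalone sketch of the Gamma time-to-visit negative log-likelihood.
import torch
from torch.distributions import Gamma

torch.manual_seed(0)
batch, seq = 2, 5

# Stand-ins for the model's tte_head outputs (must be positive for Gamma).
k_param = torch.rand(batch, seq) + 0.5          # concentration
lambda_param = torch.rand(batch, seq) + 0.5     # rate
time_to_visits = torch.tensor([[3.0, -1.0, 7.0, 2.0, -1.0],
                               [1.0, 4.0, -1.0, -1.0, 6.0]])  # -1 = no gap

indicator = time_to_visits >= 0
dist = Gamma(k_param, lambda_param)
# Clamping keeps log_prob finite when an observed gap is zero.
log_probs = dist.log_prob(torch.clamp(time_to_visits, min=1e-3))
log_probs = torch.where(indicator, log_probs, torch.zeros_like(log_probs))

total_num_tokens = indicator.sum()              # stand-in normalizer
time_to_visit_loss = -log_probs.sum() / total_num_tokens
print(time_to_visit_loss)
```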
@@ -1768,6 +1915,7 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> CehrGptSequenceClassifierOutput:
+
         cehrgpt_output = self.cehrgpt(
             input_ids=input_ids,
             value_indicators=value_indicators,
@@ -1782,17 +1930,39 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
             return_dict=return_dict,
         )

-
-
-
+        if is_sample_pack(attention_mask):
+            features = extract_features_from_packed_sequence(
+                cehrgpt_output.last_hidden_state, attention_mask
+            )
+            assert features.shape[1] == classifier_label.shape[1], (
+                "the length of the features need to be the same as the length of classifier_label. "
+                f"features.shape[1]: {features.shape[1]}, "
+                f"classifier_label.shape[1]: {classifier_label.shape[1]}"
+            )
+            assert features.shape[1] == age_at_index.shape[1], (
+                "the length of the features need to be the same as the length of age_at_index. "
+                f"features.shape[1]: {features.shape[1]}, "
+                f"age_at_index.shape[1]: {age_at_index.shape[1]}"
+            )
+            num_samples = age_at_index.shape[1]
+            features = features.view((num_samples, -1))
+            classifier_label = classifier_label.view((num_samples, -1))
+            with torch.autocast(device_type="cuda", enabled=False):
+                normalized_age = self._apply_age_norm(
+                    age_at_index.view((num_samples, 1))
+                )
+        else:
+            features = cehrgpt_output.last_hidden_state[..., -1, :]
+            # Disable autocasting for precision-sensitive operations
+            with torch.autocast(device_type="cuda", enabled=False):
+                normalized_age = self._apply_age_norm(age_at_index)

         # In case the model is in bfloat16
-        if
-        normalized_age = normalized_age.to(
+        if features.dtype != normalized_age.dtype:
+            normalized_age = normalized_age.to(features.dtype)

         # In fine-tuning, the sequences are left-padded, so we use the last element as the pooler
-
-        next_input = self.dropout(output_pooler)
+        next_input = self.dropout(features)
         next_input = torch.cat([next_input, normalized_age], dim=1)
         next_input = self.dense_layer(next_input)
         next_input = nn.functional.relu(next_input)
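For the packed classification path, extract_features_from_packed_sequence pools one hidden state per packed segment: the position just before each separator zero, i.e. the last real token of each sequence. A toy run follows, with the helper body reproduced from the top of this file; the hidden states are arranged so each feature value equals its source position.

```python
# Toy run of the packed-sequence feature pooling helper.
import torch
import torch.nn.functional as F


def extract_features_from_packed_sequence(
    hidden_state: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    # Truncate trailing padding, append one 0 so the final segment also ends
    # with a boundary, then take the position right before every boundary.
    max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
    padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
    feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
    return hidden_state[:, feature_indices]


# One packed row holding two sequences: tokens 0-2 and tokens 4-5.
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])
hidden = torch.arange(8).float().view(1, 8, 1)  # hidden_state[:, t] == t

features = extract_features_from_packed_sequence(hidden, attention_mask)
print(features.squeeze(-1))  # tensor([[2., 5.]]) -> last token of each segment
```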
cehrgpt/models/tokenization_hf_cehrgpt.py
CHANGED
@@ -25,6 +25,7 @@ from tokenizers.pre_tokenizers import WhitespaceSplit
 from tokenizers.trainers import WordLevelTrainer
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer
+from transformers.utils import logging

 from cehrgpt.gpt_utils import (
     convert_time_interval_to_time_tuple,
@@ -53,6 +54,7 @@ TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.jso
 LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
 LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
 CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+LOG = logging.get_logger("transformers")


 def truncated_sample(sample, standard_deviation):
@@ -888,6 +890,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         if isinstance(dataset, DatasetDict):
             dataset = dataset["train"]

+        LOG.info("Training the tokenizer for concepts")
         concept_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name="concept_ids",
@@ -900,6 +903,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             if concept_value_column not in row:
                 concept_value_column = "concept_values"
                 break
+        LOG.info("Training the tokenizer for values")
         value_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name=concept_value_column,