cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/simulations/time_token_simulation.py (new file)

@@ -0,0 +1,177 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.optim as optim
+from sklearn.metrics import accuracy_score, roc_auc_score
+from torch.nn import CrossEntropyLoss
+from transformers import BertConfig, BertModel
+
+from cehrgpt.simulations.time_embedding_simulation import generate_simulation_data
+
+
+class ModelTimeToken(torch.nn.Module):
+    def __init__(self, vocab_size: int):
+        super(ModelTimeToken, self).__init__()
+        self.embedding = torch.nn.Embedding(vocab_size, 16)
+        self.bert = BertModel(
+            BertConfig(
+                vocab_size=vocab_size,
+                hidden_size=16,
+                num_attention_heads=2,
+                num_hidden_layers=2,
+                intermediate_size=32,
+                hidden_dropout_prob=0.0,
+                attention_probs_dropout_prob=0.0,
+                max_position_embeddings=3,
+            ),
+            add_pooling_layer=False,
+        )
+        self.linear = torch.nn.Linear(48, 2)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        bz = input_ids.shape[0]
+        x = self.embedding(input_ids)
+        bert_output = self.bert.forward(inputs_embeds=x, return_dict=True)
+        output = bert_output.last_hidden_state.reshape((bz, 48))
+        y = self.linear(output)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(y, labels)
+        return loss, y
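Note: ModelTimeToken consumes exactly three tokens per example (concept, time token, concept); with hidden_size=16 and max_position_embeddings=3, the flattened BERT output is 3 × 16 = 48, matching the linear head. A minimal sketch of a forward pass, with an illustrative vocabulary size and batch:

```python
import torch

# Illustrative only: a 10-token vocabulary and a batch of 4 [c-x1, t-dt, c-x2] triplets.
model = ModelTimeToken(vocab_size=10)
input_ids = torch.randint(1, 10, (4, 3))
labels = torch.randint(0, 2, (4,))
loss, logits = model(input_ids, labels)
print(logits.shape)  # torch.Size([4, 2]) -- binary logits per example
```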
+
+
+def create_time_token_tokenizer(simulated_data):
+    vocab = []
+    for row in simulated_data:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t = f"t-{t2 - t1}"
+        if x1 not in vocab:
+            vocab.append(x1)
+        if x2 not in vocab:
+            vocab.append(x2)
+        if t not in vocab:
+            vocab.append(t)
+    return {c: i + 1 for i, c in enumerate(vocab)}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
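Note: the tokenizer maps each concept and each distinct time delta to an integer id. A toy illustration (the rows are made up):

```python
import numpy as np

# Toy rows in the simulated_data layout (x1, x2, t1, t2, y):
rows = np.array([[3, 7, 0, 30, 1], [3, 7, 0, 60, 0]])
print(create_time_token_tokenizer(rows))
# {'c-3': 1, 'c-7': 2, 't-30': 3, 't-60': 4} -- ids start at 1, leaving 0 free
```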
+
+
+def eval_step(simulated_data, time_token_tokenizer, time_embedding_model):
+    time_embedding_model.eval()
+    eval_input_ids = []
+    eval_y = []
+    for row in simulated_data:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t = f"t-{t2 - t1}"
+        eval_input_ids.append(
+            [
+                time_token_tokenizer[x1],
+                time_token_tokenizer[t],
+                time_token_tokenizer[x2],
+            ]
+        )
+        eval_y.append(y)
+    with torch.no_grad():
+        batched_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(device)
+        batched_y = np.asarray(eval_y)
+        # Forward pass without labels (no loss is computed)
+        _, y_pred = time_embedding_model(batched_input_ids)
+        y_probs = torch.nn.functional.softmax(y_pred, dim=1)
+        y_probs = y_probs.detach().cpu().numpy()
+        roc_auc = roc_auc_score(batched_y, y_probs[:, 1])
+        accuracy = accuracy_score(batched_y, y_probs[:, 1] > y_probs[:, 0])
+        print(f"ROC AUC: {roc_auc}")
+        print(f"Accuracy: {accuracy}")
+        return accuracy, roc_auc
+
+
+def train_step(
+    simulated_data, time_token_tokenizer, time_embedding_model, time_embedding_optimizer
+):
+    batched_input_ids = []
+    batched_y = []
+    # Sample a fixed mini-batch of 8 rows (the --batch_size argument is not used here)
+    indices = np.random.choice(simulated_data.shape[0], size=8, replace=False)
+    for row in simulated_data[indices, :]:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t = f"t-{t2 - t1}"
+        batched_input_ids.append(
+            [
+                time_token_tokenizer[x1],
+                time_token_tokenizer[t],
+                time_token_tokenizer[x2],
+            ]
+        )
+        batched_y.append(y)
+    batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
+    batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)
+    # Zero the gradients
+    time_embedding_optimizer.zero_grad()
+    # Compute loss and forward pass
+    loss, _ = time_embedding_model(batched_input_ids, batched_y)
+    # Backward pass (compute gradients)
+    loss.backward()
+    # Update model parameters
+    time_embedding_optimizer.step()
+    return loss
+
+
+def main(args):
+    simulated_data = generate_simulation_data(args.n_samples)
+    time_token_tokenizer = create_time_token_tokenizer(simulated_data)
+    time_embedding_model = ModelTimeToken(len(time_token_tokenizer) + 1).to(device)
+    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
+    steps = []
+    roc_aucs = []
+    accuracies = []
+    for step in range(args.n_steps):
+        loss = train_step(
+            simulated_data,
+            time_token_tokenizer,
+            time_embedding_model,
+            time_embedding_optimizer,
+        )
+        print(f"Step {step}: Loss = {loss.item()}")
+        # Evaluation every eval_frequency steps
+        if step % args.eval_frequency == 0 and step > 0:
+            accuracy, roc_auc = eval_step(
+                simulated_data, time_token_tokenizer, time_embedding_model
+            )
+            steps.append(step)
+            roc_aucs.append(roc_auc)
+            accuracies.append(accuracy)
+    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
+
+
+if __name__ == "__main__":
+    import argparse
+    import json
+    from pathlib import Path
+
+    parser = argparse.ArgumentParser("Model with time token simulation")
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--n_steps", type=int, default=10000)
+    parser.add_argument("--n_samples", type=int, default=1000)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--eval_frequency", type=int, default=100)
+    args = parser.parse_args()
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+    metrics = main(args)
+    with open(output_dir / "time_token_metrics.json", "w") as f:
+        json.dump(metrics, f)
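Note: the simulation persists its metrics to time_token_metrics.json under --output_dir. A quick way to inspect a finished run (paths illustrative):

```python
import json
from pathlib import Path

# After a run such as:
#   python -m cehrgpt.simulations.time_token_simulation --output_dir ./sim_out
metrics = json.loads((Path("./sim_out") / "time_token_metrics.json").read_text())
print(metrics["steps"][-1], metrics["roc_auc"][-1])  # last evaluated step and its ROC AUC
```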
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py (new file)

@@ -0,0 +1,467 @@
+import glob
+import os
+import shutil
+import uuid
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
+from cehrbert.runners.runner_util import generate_prepared_ds_path
+from datasets import concatenate_datasets, load_from_disk
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers.trainer_utils import is_main_process
+from transformers.utils import is_flash_attn_2_available, logging
+
+from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
+from cehrgpt.models.hf_cehrgpt import (
+    CEHRGPT2Model,
+    extract_features_from_packed_sequence,
+)
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import prepare_finetune_dataset
+from cehrgpt.runners.gpt_runner_util import parse_runner_args
+from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
+
+LOG = logging.get_logger("transformers")
+
+
+def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
+    if torch_dtype and hasattr(torch, torch_dtype):
+        return getattr(torch, torch_dtype)
+    return torch.float32
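A quick illustration of the dtype helper (names as defined above):

```python
import torch

# get_torch_dtype maps the configured string onto a torch dtype and
# falls back to float32 when the string is missing or unknown.
assert get_torch_dtype("bfloat16") is torch.bfloat16
assert get_torch_dtype("float16") is torch.float16
assert get_torch_dtype(None) is torch.float32
```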
+
+
+def extract_averaged_embeddings_from_packed_sequence(
+    hidden_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    ve_token_indicators: torch.BoolTensor,
+) -> torch.Tensor:
+    """
+    Average the hidden states of the VE tokens within each sample of a packed sequence.
+
+    Args:
+        hidden_states: (batch_size=1, seq_len, hidden_dim) tensor
+        attention_mask: (batch_size=1, seq_len) tensor, where 0 indicates padding
+        ve_token_indicators: (batch_size=1, seq_len) bool tensor, True if token is VE token
+    Returns:
+        (num_samples, hidden_dim) tensor: averaged embeddings over VE tokens for each sample
+    """
+    # Step 1: Create segment IDs
+    mask = attention_mask[0]  # (seq_len,)
+    segment_ids = (mask == 0).cumsum(dim=0) + 1  # start segment IDs from 1
+    segment_ids = (segment_ids * mask).to(torch.int32)  # set PAD positions back to 0
+
+    # Step 2: Only keep tokens that are both valid and VE tokens
+    valid = (segment_ids > 0) & (ve_token_indicators[0])
+    valid_embeddings = hidden_states[0, valid].to(
+        torch.float32
+    )  # (num_valid_ve_tokens, hidden_dim)
+    valid_segments = segment_ids[valid]  # (num_valid_ve_tokens,)
+
+    # Step 3: Group by segment id and average
+    num_segments = int(segment_ids.max().item())
+
+    sample_embeddings = torch.zeros(
+        num_segments, hidden_states.size(-1), device=hidden_states.device
+    )
+    counts = torch.zeros(num_segments, device=hidden_states.device)
+
+    sample_embeddings.index_add_(0, valid_segments - 1, valid_embeddings)
+    counts.index_add_(
+        0, valid_segments - 1, torch.ones_like(valid_segments, dtype=counts.dtype)
+    )
+
+    # Avoid divide-by-zero (if some segments have no VE tokens, set their embeddings to zero)
+    counts = counts.masked_fill(counts == 0, 1.0)
+
+    sample_embeddings = sample_embeddings / counts.unsqueeze(-1)
+
+    return sample_embeddings
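The cumulative-sum trick above is the crux: pad positions (attention_mask == 0) act as sample separators inside a packed sequence. A small worked example, assuming single-token padding between packed samples:

```python
import torch

mask = torch.tensor([1, 1, 1, 0, 1, 1, 0, 1])
segment_ids = (mask == 0).cumsum(dim=0) + 1   # tensor([1, 1, 1, 2, 2, 2, 3, 3])
segment_ids = segment_ids * mask              # tensor([1, 1, 1, 0, 2, 2, 0, 3])
# Non-zero ids label which packed sample each token belongs to; index_add_
# then sums the VE-token embeddings per sample before dividing by the counts.
```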
+
+
+def main():
+    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path
+    )
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
+    cehrgpt_model = (
+        CEHRGPT2Model.from_pretrained(
+            model_args.model_name_or_path,
+            attn_implementation=(
+                "flash_attention_2" if is_flash_attn_2_available() else "eager"
+            ),
+            torch_dtype=torch_dtype,
+        )
+        .eval()
+        .to(device)
+    )
+    prepared_ds_path = generate_prepared_ds_path(
+        data_args, model_args, data_folder=data_args.cohort_folder
+    )
+    cache_file_collector = CacheFileCollector()
+    processed_dataset = None
+    if any(prepared_ds_path.glob("*")):
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        processed_dataset = load_from_disk(str(prepared_ds_path))
+        LOG.info("Prepared dataset loaded from disk...")
+        if cehrgpt_args.expand_tokenizer:
+            if tokenizer_exists(training_args.output_dir):
+                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                    training_args.output_dir
+                )
+            else:
+                LOG.warning(
+                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
+                    f"when the dataset has been processed and expand_tokenizer is set to True. "
+                    f"Please delete the processed dataset at {prepared_ds_path}."
+                )
+                processed_dataset = None
+                shutil.rmtree(prepared_ds_path)
+
+    if processed_dataset is None:
+        if is_main_process(training_args.local_rank):
+            # Organize them into a single DatasetDict
+            final_splits = prepare_finetune_dataset(
+                data_args, training_args, cehrgpt_args, cache_file_collector
+            )
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                if tokenizer_exists(new_tokenizer_path):
+                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                        new_tokenizer_path
+                    )
+                else:
+                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=cehrgpt_tokenizer,
+                        dataset=final_splits["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                    )
+                    cehrgpt_tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
+                    )
+
+            # TODO: temp solution, this column is mixed-typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = final_splits["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+            processed_dataset = create_cehrgpt_finetuning_dataset(
+                dataset=final_splits,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(prepared_ds_path)
+                processed_dataset.cleanup_cache_files()
+
+            # Remove all the cached files if processed_dataset.cleanup_cache_files() did not remove them already
+            cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()
+
+        # Load the dataset from disk again so every rank reads the same copy in torch distributed training
+        processed_dataset = load_from_disk(str(prepared_ds_path))
+
+    # Gather the features that already exist so those subjects can be skipped
+    feature_folders = glob.glob(
+        os.path.join(training_args.output_dir, "*", "features", "*.parquet")
+    )
+    if feature_folders:
+        existing_features = pd.concat(
+            [
+                pd.read_parquet(f, columns=["subject_id", "prediction_time_posix"])
+                for f in feature_folders
+            ],
+            ignore_index=True,
+        )
+        subject_prediction_tuples = set(
+            existing_features.apply(
+                lambda row: f"{int(row['subject_id'])}-{int(row['prediction_time_posix'])}",
+                axis=1,
+            ).tolist()
+        )
+        processed_dataset = processed_dataset.filter(
+            lambda _batch: [
+                f"{int(subject)}-{int(time)}" not in subject_prediction_tuples
+                for subject, time in zip(_batch["person_id"], _batch["index_date"])
+            ],
+            num_proc=data_args.preprocessing_num_workers,
+            batch_size=data_args.preprocessing_batch_size,
+            batched=True,
+        )
+        LOG.info(
+            "The datasets after filtering (train: %s, validation: %s, test: %s)",
+            len(processed_dataset["train"]),
+            len(processed_dataset["validation"]),
+            len(processed_dataset["test"]),
+        )
LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
|
222
|
+
LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
|
223
|
+
if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
|
224
|
+
cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
|
225
|
+
if (
|
226
|
+
cehrgpt_model.config.max_position_embeddings
|
227
|
+
< model_args.max_position_embeddings
|
228
|
+
):
|
229
|
+
LOG.info(
|
230
|
+
f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
|
231
|
+
)
|
232
|
+
cehrgpt_model.config.max_position_embeddings = (
|
233
|
+
model_args.max_position_embeddings
|
234
|
+
)
|
235
|
+
cehrgpt_model.resize_position_embeddings(model_args.max_position_embeddings)
|
236
|
+
|
237
|
+
train_set = concatenate_datasets(
|
238
|
+
[processed_dataset["train"], processed_dataset["validation"]]
|
239
|
+
)
|
240
|
+
|
241
|
+
+
+    if cehrgpt_args.sample_packing:
+        # With sample packing, each batch is a single packed sequence holding many samples
+        per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            cehrgpt_model.config.max_position_embeddings,
+        )
+        train_batch_sampler = SamplePackingBatchSampler(
+            lengths=train_set["num_of_concepts"],
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
+            drop_last=training_args.dataloader_drop_last,
+            seed=training_args.seed,
+        )
+        test_batch_sampler = SamplePackingBatchSampler(
+            lengths=processed_dataset["test"]["num_of_concepts"],
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
+            drop_last=training_args.dataloader_drop_last,
+            seed=training_args.seed,
+        )
+    else:
+        data_collator_fn = CehrGptDataCollator
+        train_batch_sampler = None
+        test_batch_sampler = None
+        per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
+    # We suppress the additional learning objectives in fine-tuning
+    data_collator = data_collator_fn(
+        tokenizer=cehrgpt_tokenizer,
+        max_length=(
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else model_args.max_position_embeddings
+        ),
+        include_values=cehrgpt_model.config.include_values,
+        pretraining=False,
+        include_ttv_prediction=False,
+        use_sub_time_tokenization=False,
+        include_demographics=cehrgpt_args.include_demographics,
+    )
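A note on the partial above: it pre-binds the two packing limits positionally, so the single data_collator_fn(...) call works for both collators. Schematically, assuming the packing collator takes the two limits as its leading positional parameters:

```python
from functools import partial

# Hypothetical stand-in for SamplePackingCehrGptDataCollator, showing only the
# call shape: the two limits are bound positionally now, and the remaining
# arguments arrive later as keywords.
def packing_collator(max_tokens_per_batch, max_position_embeddings, *, tokenizer, max_length):
    return (max_tokens_per_batch, max_position_embeddings, tokenizer, max_length)

collator_fn = partial(packing_collator, 16384, 1024)
print(collator_fn(tokenizer="tok", max_length=16384))  # (16384, 1024, 'tok', 16384)
```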
+
+    train_loader = DataLoader(
+        dataset=train_set,
+        batch_size=per_device_eval_batch_size,
+        num_workers=training_args.dataloader_num_workers,
+        collate_fn=data_collator,
+        pin_memory=training_args.dataloader_pin_memory,
+        batch_sampler=train_batch_sampler,
+    )
+
+    test_dataloader = DataLoader(
+        dataset=processed_dataset["test"],
+        batch_size=per_device_eval_batch_size,
+        num_workers=training_args.dataloader_num_workers,
+        collate_fn=data_collator,
+        pin_memory=training_args.dataloader_pin_memory,
+        batch_sampler=test_batch_sampler,
+    )
+
+    # Loading demographics
+    print("Loading demographics as a dictionary")
+    demographics_df = pd.concat(
+        [
+            pd.read_parquet(
+                data_dir,
+                columns=[
+                    "person_id",
+                    "index_date",
+                    "gender_concept_id",
+                    "race_concept_id",
+                ],
+            )
+            for data_dir in [data_args.data_folder, data_args.test_data_folder]
+        ]
+    )
+    demographics_df["index_date"] = demographics_df.index_date.dt.date
+    demographics_dict = {
+        (row["person_id"], row["index_date"]): {
+            "gender_concept_id": row["gender_concept_id"],
+            "race_concept_id": row["race_concept_id"],
+        }
+        for _, row in demographics_df.iterrows()
+    }
+
+    data_loaders = [("train", train_loader), ("test", test_dataloader)]
+
+    ve_token_id = cehrgpt_tokenizer._convert_token_to_id("[VE]")
+    for split, data_loader in data_loaders:
+        # Ensure the prediction folder exists
+        feature_output_folder = (
+            Path(training_args.output_dir) / "features_with_label" / f"{split}_features"
+        )
+        feature_output_folder.mkdir(parents=True, exist_ok=True)
+
+        LOG.info("Generating features for %s set at %s", split, feature_output_folder)
+
+        with torch.no_grad():
+            for index, batch in enumerate(
+                tqdm(data_loader, desc="Generating features")
+            ):
+                prediction_time_ages = (
+                    batch.pop("age_at_index").numpy().astype(float).squeeze()
+                )
+                if prediction_time_ages.ndim == 0:
+                    prediction_time_ages = np.asarray([prediction_time_ages])
+
+                person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+                if person_ids.ndim == 0:
+                    person_ids = np.asarray([person_ids])
+                prediction_time_posix = batch.pop("index_date").numpy().squeeze()
+                if prediction_time_posix.ndim == 0:
+                    prediction_time_posix = np.asarray([prediction_time_posix])
+                prediction_time = list(
+                    map(datetime.fromtimestamp, prediction_time_posix)
+                )
+                labels = (
+                    batch.pop("classifier_label")
+                    .float()
+                    .cpu()
+                    .numpy()
+                    .astype(bool)
+                    .squeeze()
+                )
+                if labels.ndim == 0:
+                    labels = np.asarray([labels])
+
+                batch = {k: v.to(device) for k, v in batch.items()}
+                # Forward pass
+                cehrgpt_output = cehrgpt_model(
+                    **batch, output_attentions=False, output_hidden_states=False
+                )
+                if cehrgpt_args.sample_packing:
+                    if cehrgpt_args.average_over_sequence:
+                        ve_token_indicators: torch.BoolTensor = (
+                            batch["input_ids"] == ve_token_id
+                        )
+                        features = (
+                            extract_averaged_embeddings_from_packed_sequence(
+                                cehrgpt_output.last_hidden_state,
+                                batch["attention_mask"],
+                                ve_token_indicators,
+                            )
+                            .cpu()
+                            .float()
+                            .detach()
+                            .numpy()
+                        )
+                    else:
+                        features = (
+                            extract_features_from_packed_sequence(
+                                cehrgpt_output.last_hidden_state,
+                                batch["attention_mask"],
+                            )
+                            .cpu()
+                            .float()
+                            .detach()
+                            .numpy()
+                            .squeeze(axis=0)
+                        )
+                else:
+                    if cehrgpt_args.average_over_sequence:
+                        features = torch.where(
+                            batch["attention_mask"].unsqueeze(dim=-1).to(torch.bool),
+                            cehrgpt_output.last_hidden_state,
+                            0,
+                        )
+                        # Average across the unmasked positions of the sequence
+                        features = features.sum(dim=1) / batch["attention_mask"].sum(
+                            dim=1, keepdim=True
+                        )
+                        features = features.cpu().float().detach().numpy()
+                    else:
+                        last_end_token = any(
+                            [
+                                cehrgpt_tokenizer.end_token_id == input_id
+                                for input_id in batch.pop("input_ids")
+                                .cpu()
+                                .numpy()
+                                .squeeze()
+                                .tolist()
+                            ]
+                        )
+                        # If the sequence ends with [END], use the token right before it
+                        last_token_index = -2 if last_end_token else -1
+                        LOG.debug(
+                            "Using the hidden state at token index %s as the feature",
+                            last_token_index,
+                        )
+                        features = (
+                            cehrgpt_output.last_hidden_state[..., last_token_index, :]
+                            .cpu()
+                            .float()
+                            .detach()
+                            .numpy()
+                        )
+
+                # Handle the features as a list of arrays (one array per row)
+                features_list = [feature for feature in features]
+                race_concept_ids = []
+                gender_concept_ids = []
+                for person_id, index_date in zip(person_ids, prediction_time):
+                    key = (person_id, index_date.date())
+                    if key in demographics_dict:
+                        demographics = demographics_dict[key]
+                        gender_concept_ids.append(demographics["gender_concept_id"])
+                        race_concept_ids.append(demographics["race_concept_id"])
+                    else:
+                        gender_concept_ids.append(0)
+                        race_concept_ids.append(0)
+
+                features_pd = pd.DataFrame(
+                    {
+                        "subject_id": person_ids,
+                        "prediction_time": prediction_time,
+                        "prediction_time_posix": prediction_time_posix,
+                        "boolean_value": labels,
+                        "age_at_index": prediction_time_ages,
+                    }
+                )
+                # Add features as a separate column where each row contains a feature array
+                features_pd["features"] = features_list
+                features_pd["race_concept_id"] = race_concept_ids
+                features_pd["gender_concept_id"] = gender_concept_ids
+                features_pd.to_parquet(
+                    feature_output_folder / f"{uuid.uuid4()}.parquet"
+                )
+
+
+if __name__ == "__main__":
+    main()
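The parquet files written above carry one row per prediction (subject_id, prediction_time, boolean_value, age_at_index, demographics, and a features array), which is the shape the linear-probing step consumes. A minimal sketch of how such features could feed a probe; the paths and the use of scikit-learn's LogisticRegression are illustrative, not the actual train_with_cehrgpt_features.py implementation:

```python
import glob

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

def load_split(pattern: str) -> pd.DataFrame:
    # Each parquet file holds one batch of rows written by the feature extractor.
    return pd.concat([pd.read_parquet(f) for f in glob.glob(pattern)], ignore_index=True)

train = load_split("output_dir/features_with_label/train_features/*.parquet")
test = load_split("output_dir/features_with_label/test_features/*.parquet")

# Each row's "features" cell holds one embedding vector; stack them into a matrix.
probe = LogisticRegression(max_iter=1000).fit(
    np.stack(train["features"].to_list()), train["boolean_value"]
)
probs = probe.predict_proba(np.stack(test["features"].to_list()))[:, 1]
print(f"ROC AUC: {roc_auc_score(test['boolean_value'], probs)}")
```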