cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import torch
|
5
|
+
import torch.optim as optim
|
6
|
+
from sklearn.metrics import accuracy_score, roc_auc_score
|
7
|
+
from torch.nn import CrossEntropyLoss
|
8
|
+
from transformers import BertConfig, BertModel
|
9
|
+
|
10
|
+
from cehrgpt.simulations.time_embedding_simulation import generate_simulation_data
|
11
|
+
|
12
|
+
|
13
|
+
class ModelTimeToken(torch.nn.Module):
    """Small BERT classifier over fixed-length token sequences.

    In this simulation each example is ``[concept_1, time_token, concept_2]``.
    Token ids are embedded with a standalone embedding table and passed to BERT
    as ``inputs_embeds``, so BERT's own word-embedding table is bypassed. The
    final hidden states of all positions are flattened and projected to two
    class logits.

    Args:
        vocab_size: Number of rows in the embedding table; every input id must
            be in ``[0, vocab_size)``.
        hidden_size: Embedding / BERT hidden width. Must be divisible by the 2
            attention heads. Defaults to 16 (the previous hard-coded value).
        seq_len: Tokens per example. It is also BERT's
            ``max_position_embeddings``. Defaults to 3.
    """

    def __init__(self, vocab_size: int, hidden_size: int = 16, seq_len: int = 3):
        super().__init__()
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.bert = BertModel(
            BertConfig(
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                num_attention_heads=2,
                num_hidden_layers=2,
                intermediate_size=2 * hidden_size,
                hidden_dropout_prob=0.0,
                attention_probs_dropout_prob=0.0,
                max_position_embeddings=seq_len,
            ),
            add_pooling_layer=False,
        )
        # Classifier input is the concatenation of all per-position hidden states.
        self.linear = torch.nn.Linear(seq_len * hidden_size, 2)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.FloatTensor], torch.FloatTensor]:
        """Return ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is not given.

        Args:
            input_ids: Token ids, reshaped below to ``(batch, seq_len)``.
            labels: Optional class indices (0 or 1) per example.
        """
        batch_size = input_ids.shape[0]
        embeds = self.embedding(input_ids)
        # Invoke the module (not .forward) so registered hooks run.
        bert_output = self.bert(inputs_embeds=embeds, return_dict=True)
        flat = bert_output.last_hidden_state.reshape(
            (batch_size, self.seq_len * self.hidden_size)
        )
        logits = self.linear(flat)
        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits, labels)
        return loss, logits
|
47
|
+
|
48
|
+
|
49
|
+
def create_time_token_tokenizer(simulated_data):
    """Build the token -> id vocabulary used by the time-token model.

    Each row is unpacked as ``(x1, x2, t1, t2, y)``. The tokens ``c-{x1}``,
    ``c-{x2}`` and ``t-{t2 - t1}`` are registered in first-seen order. Ids are
    numbered from 1, so id 0 is never assigned.

    Args:
        simulated_data: Iterable of 5-element rows.

    Returns:
        dict mapping each token string to its integer id.
    """
    vocab = {}
    for x1, x2, t1, t2, _ in simulated_data:
        for token in (f"c-{x1}", f"c-{x2}", f"t-{t2 - t1}"):
            # setdefault keeps the first id ever given to a token. The dict
            # lookup replaces the O(len(vocab)) list-membership scan.
            vocab.setdefault(token, len(vocab) + 1)
    return vocab
|
63
|
+
|
64
|
+
|
65
|
+
# Module-wide compute device: the first CUDA device when one is visible, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
66
|
+
|
67
|
+
|
68
|
+
def eval_step(simulated_data, time_token_tokenizer, time_embedding_model):
    """Score every row of ``simulated_data`` in one batch and print ROC AUC / accuracy.

    Each row ``(x1, x2, t1, t2, y)`` is encoded as
    ``[c-{x1}, t-{t2 - t1}, c-{x2}]`` using ``time_token_tokenizer``. The model
    runs under ``torch.no_grad()`` on the module-level ``device``.

    The model's train/eval mode is restored afterwards. Otherwise every later
    training step would run in eval mode.

    Args:
        simulated_data: Iterable of 5-element rows.
        time_token_tokenizer: Mapping from token string to id.
        time_embedding_model: Model returning ``(loss, logits)``.

    Returns:
        ``(accuracy, roc_auc)``.
    """
    was_training = time_embedding_model.training
    time_embedding_model.eval()

    eval_input_ids = []
    eval_y = []
    for x1, x2, t1, t2, y in simulated_data:
        eval_input_ids.append(
            [
                time_token_tokenizer[f"c-{x1}"],
                time_token_tokenizer[f"t-{t2 - t1}"],
                time_token_tokenizer[f"c-{x2}"],
            ]
        )
        eval_y.append(y)

    try:
        with torch.no_grad():
            batched_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(
                device
            )
            _, y_pred = time_embedding_model(batched_input_ids)
            y_probs = torch.nn.functional.softmax(y_pred, dim=1).cpu().numpy()
    finally:
        time_embedding_model.train(was_training)

    batched_y = np.asarray(eval_y)
    roc_auc = roc_auc_score(batched_y, y_probs[:, 1])
    # Predicted class is 1 when its probability beats class 0.
    accuracy = accuracy_score(batched_y, y_probs[:, 1] > y_probs[:, 0])
    print(f"ROC AUC: {roc_auc}")
    print(f"Accuracy: {accuracy}")
    return accuracy, roc_auc
|
97
|
+
|
98
|
+
|
99
|
+
def train_step(
    simulated_data,
    time_token_tokenizer,
    time_embedding_model,
    time_embedding_optimizer,
    batch_size: int = 8,
):
    """Run one optimizer update on a random mini-batch drawn without replacement.

    Args:
        simulated_data: 2-D array whose rows are ``(x1, x2, t1, t2, y)``.
        time_token_tokenizer: Mapping from token string to id.
        time_embedding_model: Model returning ``(loss, logits)`` when given labels.
        time_embedding_optimizer: Optimizer over the model's parameters.
        batch_size: Rows per step (default 8). It is capped at the number of
            rows so a small dataset does not make sampling fail.

    Returns:
        The loss tensor for this step.
    """
    # eval_step() calls .eval() on the model, so switch back to training mode here.
    time_embedding_model.train()

    n_rows = simulated_data.shape[0]
    indices = np.random.choice(n_rows, size=min(batch_size, n_rows), replace=False)

    batched_input_ids = []
    batched_y = []
    for x1, x2, t1, t2, y in simulated_data[indices, :]:
        batched_input_ids.append(
            [
                time_token_tokenizer[f"c-{x1}"],
                time_token_tokenizer[f"t-{t2 - t1}"],
                time_token_tokenizer[f"c-{x2}"],
            ]
        )
        batched_y.append(y)
    batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
    batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)

    time_embedding_optimizer.zero_grad()
    loss, _ = time_embedding_model(batched_input_ids, batched_y)
    loss.backward()
    time_embedding_optimizer.step()
    return loss
|
129
|
+
|
130
|
+
|
131
|
+
def main(args):
    """Train the time-token model on simulated data and record periodic eval metrics.

    Args:
        args: Parsed CLI namespace. Reads ``n_samples``, ``n_steps`` and
            ``eval_frequency``. ``--batch_size`` is parsed by the CLI but is not
            forwarded to ``train_step`` here.

    Returns:
        dict with parallel lists under ``"steps"``, ``"roc_auc"`` and ``"accuracy"``.
    """
    simulated_data = generate_simulation_data(args.n_samples)
    time_token_tokenizer = create_time_token_tokenizer(simulated_data)
    # Token ids run from 1 to len(tokenizer), so the embedding needs len + 1 rows.
    time_embedding_model = ModelTimeToken(len(time_token_tokenizer) + 1).to(device)
    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)

    steps = []
    roc_aucs = []
    accuracies = []
    for step in range(args.n_steps):
        loss = train_step(
            simulated_data,
            time_token_tokenizer,
            time_embedding_model,
            time_embedding_optimizer,
        )
        print(f"Step {step}: Loss = {loss.item()}")
        # Evaluate periodically, keyed on the loop counter `step`. A condition on
        # the constant args.n_steps would be true on every step or on none.
        if (
            args.eval_frequency > 0
            and step % args.eval_frequency == 0
            and step > args.eval_frequency
        ):
            accuracy, roc_auc = eval_step(
                simulated_data, time_token_tokenizer, time_embedding_model
            )
            steps.append(step)
            roc_aucs.append(roc_auc)
            accuracies.append(accuracy)
    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
|
159
|
+
|
160
|
+
|
161
|
+
if __name__ == "__main__":
    import argparse
    import json
    from pathlib import Path

    # Command-line entry point: run the simulation and dump metrics as JSON.
    cli = argparse.ArgumentParser("Model with time token simulation")
    cli.add_argument("--output_dir", type=str, required=True)
    for flag, default in (
        ("--n_steps", 10000),
        ("--n_samples", 1000),
        ("--batch_size", 128),
        ("--eval_frequency", 100),
    ):
        cli.add_argument(flag, type=int, default=default)
    parsed = cli.parse_args()

    out_path = Path(parsed.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    results = main(parsed)
    with (out_path / "time_token_metrics.json").open("w") as fh:
        json.dump(results, fh)
|
@@ -0,0 +1,146 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pandas as pd
|
3
|
+
|
4
|
+
# Define race mapping
|
5
|
+
# Race concept ids rewritten to another id before use; ids not listed here are
# returned unchanged by map_race. The "0" target presumably denotes an unknown
# race (TODO confirm).
race_mapping = {
    **dict.fromkeys(("38003613", "38003610"), "8557"),
    "38003579": "8515",
    "44814653": "0",
}
|
11
|
+
|
12
|
+
# Invalid age groups
|
13
|
+
# Age-group labels (as produced by age_group_func) whose patients are excluded:
# every decade from 100 to 200, a handful of specific implausible decades, and
# negative ages.
invalid_age_groups = (
    [f"age:{lo}-{lo + 10}" for lo in range(100, 200, 10)]
    + [f"age:{lo}-{lo + 10}" for lo in (640, 680, 730, 740, 890, 900)]
    + ["age:-10-0"]
)
|
32
|
+
|
33
|
+
|
34
|
+
def age_group_func(age_str):
    """Map an ``"age:N"`` token to its 10-year bucket label ``"age:LO-HI"``.

    ``LO`` is ``N`` rounded down to a multiple of 10 (floor semantics, so
    ``"age:-5"`` becomes ``"age:-10-0"``) and ``HI`` is ``LO + 10``.

    Args:
        age_str: Token whose integer age follows the first ``":"``.

    Returns:
        The bucket label string.
    """
    years = int(age_str.split(":")[1])
    # Python's % floors toward -inf, so this equals (years // 10) * 10.
    lower = years - years % 10
    return f"age:{lower}-{lower + 10}"
|
47
|
+
|
48
|
+
|
49
|
+
def map_race(race):
    """Return the replacement for ``race`` from ``race_mapping``, or ``race`` itself if unmapped."""
    return race_mapping[race] if race in race_mapping else race
|
51
|
+
|
52
|
+
|
53
|
+
def _add_demographic_columns(patient_sequence):
    """Add demographics/year/age/gender/race/death columns to ``patient_sequence`` in place."""
    # The first four concept ids are read positionally as year, age, gender, race.
    demographics = patient_sequence.concept_ids.apply(lambda ids: ids[:4])
    patient_sequence["demographics"] = demographics
    patient_sequence["year"] = demographics.apply(lambda c: c[0])
    patient_sequence["age"] = demographics.apply(lambda c: c[1]).apply(age_group_func)
    patient_sequence["gender"] = demographics.apply(lambda c: c[2])
    patient_sequence["race"] = demographics.apply(lambda c: c[3])
    # death = 1 when the second-to-last concept id is "[DEATH]".
    patient_sequence["death"] = patient_sequence.concept_ids.apply(
        lambda ids: int(ids[-2] == "[DEATH]")
    )


def _age_group_sampling_probs(demographics_clean):
    """Return one row per age group with its share ("prob") and down-weighted share ("adjusted_prob").

    Groups are ordered youngest-first by the numeric lower bound of their
    ``"age:LO-HI"`` label. They are multiplied by ``n, n-1, ..., 1``
    (``n`` = number of groups, i.e. 10..1 for the ten 0-100 decades) and then
    renormalised, so younger groups get larger weights.
    """
    probs = (
        demographics_clean.groupby(["age"])["person_id"].count()
        / len(demographics_clean)
    ).reset_index()
    probs.rename(columns={"person_id": "prob"}, inplace=True)
    # Sort numerically. Lexicographic label order breaks for labels such as
    # "age:200-210", which would sort between "age:20-30" and "age:30-40".
    lower_bound = probs.age.map(lambda label: int(label.split(":")[1].rsplit("-", 1)[0]))
    probs = probs.iloc[np.argsort(lower_bound.to_numpy(), kind="stable")].reset_index(
        drop=True
    )
    # Size the multipliers to the actual group count. A fixed length-10 Series
    # aligned by index would silently give NaN (zero) weight to groups past the
    # 10th, and would misalign weights when fewer groups exist.
    multipliers = np.arange(len(probs), 0, -1)
    adjusted = probs.prob * multipliers
    probs["adjusted_prob"] = adjusted / adjusted.sum()
    return probs


def main(args):
    """Write a train/validation patient split whose train set over-samples younger age groups.

    Steps:

    1. Read the parquet at ``args.patient_sequence``.
    2. Derive demographics from the leading concept ids.
    3. Drop patients whose gender is ``"0"`` or whose age group is in
       ``invalid_age_groups``.
    4. Sample up to ``args.num_patients`` patients (without replacement) as
       ``"train"``, with age-group weights that decrease linearly with age.
       Every other remaining patient becomes ``"validation"``.
    5. Write the result as parquet to ``args.output_folder``.

    Race ids are remapped with ``map_race`` before being written.
    """
    patient_sequence = pd.read_parquet(args.patient_sequence)
    _add_demographic_columns(patient_sequence)

    # .copy() so the race remap writes to an independent frame rather than a
    # slice of patient_sequence (avoids chained assignment).
    demographics = patient_sequence[
        ["person_id", "death", "year", "age", "gender", "race", "split"]
    ].copy()
    demographics["race"] = demographics.race.apply(map_race)
    demographics_clean = demographics[
        (demographics.gender != "0") & (~demographics.age.isin(invalid_age_groups))
    ]

    probs = _age_group_sampling_probs(demographics_clean)

    np.random.seed(42)
    # Sample from the cleaned demographics so the remapped race reaches the output.
    demographics_for_sampling = demographics_clean[
        ["year", "age", "race", "gender", "person_id"]
    ].merge(probs, on="age")
    # NOTE(review): each row is weighted by its group's adjusted_prob. A group's
    # total weight therefore also scales with its head count (~count * prob * w).
    # Confirm this is intended rather than adjusted_prob / count per row.
    demographics_for_sampling["adjusted_prob"] = (
        demographics_for_sampling.adjusted_prob
        / demographics_for_sampling.adjusted_prob.sum()
    )

    # Clamp so a cohort smaller than --num_patients does not make sample() raise.
    n_train = min(args.num_patients, len(demographics_for_sampling))
    causal_train_split = demographics_for_sampling.sample(
        n_train, replace=False, weights="adjusted_prob", random_state=1
    ).assign(split="train")
    causal_val_split = demographics_for_sampling[
        ~demographics_for_sampling.person_id.isin(causal_train_split.person_id)
    ].assign(split="validation")

    causal_train_val_split = pd.concat([causal_train_split, causal_val_split])
    causal_train_val_split.to_parquet(args.output_folder, index=False)
|
123
|
+
|
124
|
+
|
125
|
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse options and build the causal split.
    cli = argparse.ArgumentParser(
        description="Arguments for a causal patient split by age groups"
    )
    cli.add_argument("--patient_sequence", required=True)
    cli.add_argument("--num_patients", default=1_000_000, type=int, required=False)
    cli.add_argument("--output_folder", required=True)
    main(cli.parse_args())
|
File without changes
|