cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/simulations/time_token_simulation.py

```diff
@@ -0,0 +1,177 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.optim as optim
+from sklearn.metrics import accuracy_score, roc_auc_score
+from torch.nn import CrossEntropyLoss
+from transformers import BertConfig, BertModel
+
+from cehrgpt.simulations.time_embedding_simulation import generate_simulation_data
+
+
+class ModelTimeToken(torch.nn.Module):
+    def __init__(self, vocab_size: int):
+        super(ModelTimeToken, self).__init__()
+        self.embedding = torch.nn.Embedding(vocab_size, 16)
+        self.bert = BertModel(
+            BertConfig(
+                vocab_size=vocab_size,
+                hidden_size=16,
+                num_attention_heads=2,
+                num_hidden_layers=2,
+                intermediate_size=32,
+                hidden_dropout_prob=0.0,
+                attention_probs_dropout_prob=0.0,
+                max_position_embeddings=3,
+            ),
+            add_pooling_layer=False,
+        )
+        self.linear = torch.nn.Linear(48, 2)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        bz = input_ids.shape[0]
+        x = self.embedding(input_ids)
+        bert_output = self.bert.forward(inputs_embeds=x, return_dict=True)
+        output = bert_output.last_hidden_state.reshape((bz, 48))
+        y = self.linear(output)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(y, labels)
+        return loss, y
+
+
+def create_time_token_tokenizer(simulated_data):
+    vocab = []
+    for row in simulated_data:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t = f"t-{t2 - t1}"
+        if x1 not in vocab:
+            vocab.append(x1)
+        if x2 not in vocab:
+            vocab.append(x2)
+        if t not in vocab:
+            vocab.append(t)
+    return {c: i + 1 for i, c in enumerate(vocab)}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def eval_step(simulated_data, time_token_tokenizer, time_embedding_model):
+    time_embedding_model.eval()
+    eval_input_ids = []
+    eval_y = []
+    for row in simulated_data:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t = f"t-{t2 - t1}"
+        eval_input_ids.append(
+            [
+                time_token_tokenizer[x1],
+                time_token_tokenizer[t],
+                time_token_tokenizer[x2],
+            ]
+        )
+        eval_y.append(y)
+    with torch.no_grad():
+        batched_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(device)
+        batched_y = np.asarray(eval_y)
+        # Compute loss and forward pass
+        _, y_pred = time_embedding_model(batched_input_ids)
+        y_probs = torch.nn.functional.softmax(y_pred, dim=1)
+        y_probs = y_probs.detach().cpu().numpy()
+        roc_auc = roc_auc_score(batched_y, y_probs[:, 1])
+        accuracy = accuracy_score(batched_y, y_probs[:, 1] > y_probs[:, 0])
+        print(f"ROC AUC: {roc_auc}")
+        print(f"Accuracy: {accuracy}")
+    return accuracy, roc_auc
+
+
+def train_step(
+    simulated_data, time_token_tokenizer, time_embedding_model, time_embedding_optimizer
+):
+    batched_input_ids = []
+    batched_y = []
+    indices = np.random.choice(simulated_data.shape[0], size=8, replace=False)
+    for row in simulated_data[indices, :]:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t = f"t-{t2 - t1}"
+        batched_input_ids.append(
+            [
+                time_token_tokenizer[x1],
+                time_token_tokenizer[t],
+                time_token_tokenizer[x2],
+            ]
+        )
+        batched_y.append(y)
+    batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
+    batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)
+    # Zero the gradients
+    time_embedding_optimizer.zero_grad()
+    # Compute loss and forward pass
+    loss, _ = time_embedding_model(batched_input_ids, batched_y)
+    # Backward pass (compute gradients)
+    loss.backward()
+    # Update model parameters
+    time_embedding_optimizer.step()
+    return loss
+
+
+def main(args):
+    simulated_data = generate_simulation_data(args.n_samples)
+    time_token_tokenizer = create_time_token_tokenizer(simulated_data)
+    time_embedding_model = ModelTimeToken(len(time_token_tokenizer) + 1).to(device)
+    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
+    steps = []
+    roc_aucs = []
+    accuracies = []
+    for step in range(args.n_steps):
+        loss = train_step(
+            simulated_data,
+            time_token_tokenizer,
+            time_embedding_model,
+            time_embedding_optimizer,
+        )
+        print(f"Step {step}: Loss = {loss.item()}")
+        # Periodic evaluation
+        if (
+            step % args.eval_frequency == 0
+            and step > 0
+        ):
+            accuracy, roc_auc = eval_step(
+                simulated_data, time_token_tokenizer, time_embedding_model
+            )
+            steps.append(step)
+            roc_aucs.append(roc_auc)
+            accuracies.append(accuracy)
+    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
+
+
+if __name__ == "__main__":
+    import argparse
+    import json
+    from pathlib import Path
+
+    parser = argparse.ArgumentParser("Model with time token simulation")
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--n_steps", type=int, default=10000)
+    parser.add_argument("--n_samples", type=int, default=1000)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--eval_frequency", type=int, default=100)
+    args = parser.parse_args()
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+    metrics = main(args)
+    with open(output_dir / "time_token_metrics.json", "w") as f:
+        json.dump(metrics, f)
```
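The script writes its metrics to `time_token_metrics.json` under `--output_dir`. A minimal sketch of reading that file back, assuming a hypothetical output directory named `sim_output` (the JSON keys come from `main()` above):

```python
# Sketch: inspect the metrics JSON produced by time_token_simulation.py.
# "sim_output" is a hypothetical --output_dir value, not part of the release.
import json
from pathlib import Path

metrics = json.loads(Path("sim_output/time_token_metrics.json").read_text())
for step, auc, acc in zip(metrics["steps"], metrics["roc_auc"], metrics["accuracy"]):
    print(f"step={step}: roc_auc={auc:.3f}, accuracy={acc:.3f}")
```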
cehrgpt/time_to_event/config/1_year_cabg.yaml

```diff
@@ -0,0 +1,23 @@
+task_name: "cabg_prediction"
+outcome_events: [
+  "43528001",
+  "43528003",
+  "43528004",
+  "43528002",
+  "4305852",
+  "4168831",
+  "2107250",
+  "2107216",
+  "2107222",
+  "2107231",
+  "4336464",
+  "4231998",
+  "4284104",
+  "2100873",
+]
+future_visit_start: 0
+future_visit_end: -1
+prediction_window_start: 0
+prediction_window_end: 365
+max_new_tokens: 1024
+include_descendants: true
```
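A minimal sketch of parsing this task config, assuming PyYAML; the diff does not show how cehrgpt itself loads the file, so this is illustrative only:

```python
# Sketch: parse the 1_year_cabg.yaml task config with PyYAML (assumption:
# cehrgpt may load it differently; this only shows the file's structure).
import yaml

with open("cehrgpt/time_to_event/config/1_year_cabg.yaml") as f:
    task_config = yaml.safe_load(f)

print(task_config["task_name"])              # "cabg_prediction"
print(len(task_config["outcome_events"]))    # 14 outcome concept IDs
print(task_config["prediction_window_end"])  # 365 (days)
```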
cehrgpt/time_to_event/time_to_event_model.py

```diff
@@ -80,20 +80,9 @@ class TimeToEventModel:
         return token in self.outcome_events
 
     def simulate(
-        self,
+        self,
+        partial_history: Union[np.ndarray, List[str]],
     ) -> List[List[str]]:
-
-        sequence_is_demographics = len(partial_history) == 4 and partial_history[
-            0
-        ].startswith("year")
-        sequence_ends_ve = is_visit_end(partial_history[-1])
-
-        if not (sequence_is_demographics | sequence_ends_ve):
-            raise ValueError(
-                "There are only two types of sequences allowed. 1) the sequence only contains "
-                "demographics; 2) the sequence ends on VE;"
-            )
-
         token_ids = self.tokenizer.encode(partial_history)
         prompt = torch.tensor(token_ids).unsqueeze(0).to(self.device)
```
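With this change `simulate()` no longer validates `partial_history`. Callers that still want the old invariant (a 4-token demographics prompt, or a sequence ending on a visit-end token) can reapply it themselves; a standalone sketch reconstructed from the removed lines, where the `is_visit_end` import path is an assumption:

```python
# Sketch: the guard simulate() used to perform, reconstructed from the
# removed lines above so callers can run it before calling simulate().
from typing import List

from cehrgpt.gpt_utils import is_visit_end  # assumed import path


def check_partial_history(partial_history: List[str]) -> None:
    is_demographics = len(partial_history) == 4 and partial_history[0].startswith("year")
    ends_on_ve = is_visit_end(partial_history[-1])
    if not (is_demographics or ends_on_ve):
        raise ValueError(
            "There are only two types of sequences allowed. 1) the sequence only "
            "contains demographics; 2) the sequence ends on VE;"
        )
```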
cehrgpt/time_to_event/time_to_event_prediction.py

```diff
@@ -118,9 +118,9 @@ def main(args):
     LOG.info(f"Top P {args.top_p}")
     LOG.info(f"Top K {args.top_k}")
 
-    cehrgpt_model.resize_position_embeddings(
-        cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
-    )
+    # cehrgpt_model.resize_position_embeddings(
+    #     cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
+    # )
 
     generation_config = TimeToEventModel.get_generation_config(
         tokenizer=cehrgpt_tokenizer,
```
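With the resize call commented out, the prompt plus `max_new_tokens` must fit within the model's existing context window. A hypothetical caller-side guard, not part of the release:

```python
# Sketch: fail fast if generation would exceed the context window, now that
# position embeddings are no longer resized. The function is hypothetical.
def assert_fits_context(prompt_len: int, max_new_tokens: int, max_position_embeddings: int) -> None:
    if prompt_len + max_new_tokens > max_position_embeddings:
        raise ValueError(
            f"prompt_len ({prompt_len}) + max_new_tokens ({max_new_tokens}) "
            f"exceeds max_position_embeddings ({max_position_embeddings})"
        )
```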
```diff
@@ -190,14 +190,22 @@ def main(args):
             args.max_n_trial,
         )
         visit_counter = sum([int(is_visit_end(_)) for _ in partial_history])
+        predicted_boolean_probability = (
+            sum([event != "0" for event in concept_time_to_event.outcome_events])
+            / len(concept_time_to_event.outcome_events)
+            if concept_time_to_event
+            else 0.0
+        )
         tte_outputs.append(
             {
-                "
-                "
+                "subject_id": record["person_id"],
+                "prediction_time": record["index_date"],
                 "visit_counter": visit_counter,
-                "
+                "boolean_value": label,
+                "predicted_boolean_probability": predicted_boolean_probability,
+                "predicted_boolean_value": None,
                 "time_to_event": time_to_event,
-                "
+                "trials": (
                     asdict(concept_time_to_event) if concept_time_to_event else None
                 ),
             }
```
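After this change each appended record follows the new column layout; an illustrative example where the values are made up but the keys match the diff:

```python
# Illustrative record appended to tte_outputs after this change; the values
# below are invented, only the keys and their sources come from the diff.
example_record = {
    "subject_id": 12345,                    # from record["person_id"]
    "prediction_time": "2020-01-01",        # from record["index_date"]
    "visit_counter": 7,                     # visit-end tokens in the history
    "boolean_value": True,                  # the observed label
    "predicted_boolean_probability": 0.25,  # share of trials predicting the outcome
    "predicted_boolean_value": None,        # left unset by this script
    "time_to_event": 42,
    "trials": None,                         # asdict(concept_time_to_event) when present
}
```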
```diff
@@ -263,9 +271,13 @@ def filter_out_existing_results(
     parquet_files = glob.glob(os.path.join(prediction_output_folder_name, "*parquet"))
     if parquet_files:
         cohort_members = set()
-        results_dataframe = pd.read_parquet(parquet_files)[
+        results_dataframe = pd.read_parquet(parquet_files)[
+            ["subject_id", "prediction_time"]
+        ]
         for row in results_dataframe.itertuples():
-            cohort_members.add(
+            cohort_members.add(
+                (row.subject_id, row.prediction_time.strftime("%Y-%m-%d"))
+            )
 
     def filter_func(batched):
         return [
```
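The added lines key previously written results by a `(subject_id, formatted prediction_time)` tuple so reruns skip finished cohort members; a self-contained sketch of that keying:

```python
# Sketch of the resume logic above: previously written (subject_id,
# "YYYY-MM-DD") pairs are collected so matching cohort rows are skipped.
from datetime import datetime

cohort_members = {(1, "2020-01-01"), (2, "2020-06-15")}

def already_predicted(subject_id: int, prediction_time: datetime) -> bool:
    return (subject_id, prediction_time.strftime("%Y-%m-%d")) in cohort_members

print(already_predicted(1, datetime(2020, 1, 1)))  # True  -> filtered out
print(already_predicted(3, datetime(2020, 1, 1)))  # False -> still processed
```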
```diff
@@ -292,12 +304,14 @@ def flush_to_disk_if_full(
         pd.DataFrame(
             tte_outputs,
             columns=[
-                "
-                "
+                "subject_id",
+                "prediction_time",
                 "visit_counter",
-                "
+                "boolean_value",
+                "predicted_boolean_probability",
+                "predicted_boolean_value",
                 "time_to_event",
-                "
+                "trials",
             ],
         ).to_parquet(output_parquet_file)
         tte_outputs.clear()
```
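The renamed columns make the flushed parquet directly usable for standard binary-classification scoring; a sketch assuming a hypothetical output file:

```python
# Sketch: score flushed predictions using the renamed columns. The parquet
# path is hypothetical, and roc_auc_score needs both classes present.
import pandas as pd
from sklearn.metrics import roc_auc_score

df = pd.read_parquet("tte_predictions/part-0.parquet")
auc = roc_auc_score(df["boolean_value"], df["predicted_boolean_probability"])
print(f"ROC AUC: {auc:.3f}")
```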
File without changes