cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/generation/omop_converter_batch.py +3 -0
  7. cehrgpt/models/config.py +10 -0
  8. cehrgpt/models/hf_cehrgpt.py +244 -73
  9. cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
  10. cehrgpt/runners/data_utils.py +243 -0
  11. cehrgpt/runners/gpt_runner_util.py +0 -10
  12. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
  13. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
  14. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
  15. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  16. cehrgpt/runners/sample_packing_trainer.py +168 -0
  17. cehrgpt/simulations/__init__.py +0 -0
  18. cehrgpt/simulations/generate_plots.py +95 -0
  19. cehrgpt/simulations/run_simulation.sh +24 -0
  20. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  21. cehrgpt/simulations/time_token_simulation.py +177 -0
  22. cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
  23. cehrgpt/tools/linear_prob/__init__.py +0 -0
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  25. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  26. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
  27. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
  28. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  29. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  30. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/simulations/time_token_simulation.py
@@ -0,0 +1,177 @@
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from sklearn.metrics import accuracy_score, roc_auc_score
+ from torch.nn import CrossEntropyLoss
+ from transformers import BertConfig, BertModel
+
+ from cehrgpt.simulations.time_embedding_simulation import generate_simulation_data
+
+
+ class ModelTimeToken(torch.nn.Module):
+     def __init__(self, vocab_size: int):
+         super(ModelTimeToken, self).__init__()
+         self.embedding = torch.nn.Embedding(vocab_size, 16)
+         self.bert = BertModel(
+             BertConfig(
+                 vocab_size=vocab_size,
+                 hidden_size=16,
+                 num_attention_heads=2,
+                 num_hidden_layers=2,
+                 intermediate_size=32,
+                 hidden_dropout_prob=0.0,
+                 attention_probs_dropout_prob=0.0,
+                 max_position_embeddings=3,
+             ),
+             add_pooling_layer=False,
+         )
+         self.linear = torch.nn.Linear(48, 2)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         labels: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+         bz = input_ids.shape[0]
+         x = self.embedding(input_ids)
+         bert_output = self.bert.forward(inputs_embeds=x, return_dict=True)
+         output = bert_output.last_hidden_state.reshape((bz, 48))
+         y = self.linear(output)
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(y, labels)
+         return loss, y
+
+
+ def create_time_token_tokenizer(simulated_data):
+     vocab = []
+     for row in simulated_data:
+         x1, x2, t1, t2, y = row
+         x1 = f"c-{x1}"
+         x2 = f"c-{x2}"
+         t = f"t-{t2 - t1}"
+         if x1 not in vocab:
+             vocab.append(x1)
+         if x2 not in vocab:
+             vocab.append(x2)
+         if t not in vocab:
+             vocab.append(t)
+     return {c: i + 1 for i, c in enumerate(vocab)}
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def eval_step(simulated_data, time_token_tokenizer, time_embedding_model):
+     time_embedding_model.eval()
+     eval_input_ids = []
+     eval_y = []
+     for row in simulated_data:
+         x1, x2, t1, t2, y = row
+         x1 = f"c-{x1}"
+         x2 = f"c-{x2}"
+         t = f"t-{t2 - t1}"
+         eval_input_ids.append(
+             [
+                 time_token_tokenizer[x1],
+                 time_token_tokenizer[t],
+                 time_token_tokenizer[x2],
+             ]
+         )
+         eval_y.append(y)
+     with torch.no_grad():
+         batched_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(device)
+         batched_y = np.asarray(eval_y)
+         # Forward pass (no labels are passed, so the returned loss is None)
+         _, y_pred = time_embedding_model(batched_input_ids)
+         y_probs = torch.nn.functional.softmax(y_pred, dim=1)
+         y_probs = y_probs.detach().cpu().numpy()
+         roc_auc = roc_auc_score(batched_y, y_probs[:, 1])
+         accuracy = accuracy_score(batched_y, y_probs[:, 1] > y_probs[:, 0])
+         print(f"ROC AUC: {roc_auc}")
+         print(f"Accuracy: {accuracy}")
+     return accuracy, roc_auc
+
+
+ def train_step(
+     simulated_data, time_token_tokenizer, time_embedding_model, time_embedding_optimizer
+ ):
+     batched_input_ids = []
+     batched_y = []
+     indices = np.random.choice(simulated_data.shape[0], size=8, replace=False)
+     for row in simulated_data[indices, :]:
+         x1, x2, t1, t2, y = row
+         x1 = f"c-{x1}"
+         x2 = f"c-{x2}"
+         t = f"t-{t2 - t1}"
+         batched_input_ids.append(
+             [
+                 time_token_tokenizer[x1],
+                 time_token_tokenizer[t],
+                 time_token_tokenizer[x2],
+             ]
+         )
+         batched_y.append(y)
+     batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
+     batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)
+     # Zero the gradients
+     time_embedding_optimizer.zero_grad()
+     # Compute loss and forward pass
+     loss, _ = time_embedding_model(batched_input_ids, batched_y)
+     # Backward pass (compute gradients)
+     loss.backward()
+     # Update model parameters
+     time_embedding_optimizer.step()
+     return loss
+
+
+ def main(args):
+     simulated_data = generate_simulation_data(args.n_samples)
+     time_token_tokenizer = create_time_token_tokenizer(simulated_data)
+     time_embedding_model = ModelTimeToken(len(time_token_tokenizer) + 1).to(device)
+     time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
+     steps = []
+     roc_aucs = []
+     accuracies = []
+     for step in range(args.n_steps):
+         loss = train_step(
+             simulated_data,
+             time_token_tokenizer,
+             time_embedding_model,
+             time_embedding_optimizer,
+         )
+         print(f"Step {step}: Loss = {loss.item()}")
+         # Evaluate every `eval_frequency` steps
+         if step % args.eval_frequency == 0 and step > 0:
+             accuracy, roc_auc = eval_step(
+                 simulated_data, time_token_tokenizer, time_embedding_model
+             )
+             steps.append(step)
+             roc_aucs.append(roc_auc)
+             accuracies.append(accuracy)
+     return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
+
+
+ if __name__ == "__main__":
+     import argparse
+     import json
+     from pathlib import Path
+
+     parser = argparse.ArgumentParser("Model with time token simulation")
+     parser.add_argument("--output_dir", type=str, required=True)
+     parser.add_argument("--n_steps", type=int, default=10000)
+     parser.add_argument("--n_samples", type=int, default=1000)
+     parser.add_argument("--batch_size", type=int, default=128)
+     parser.add_argument("--eval_frequency", type=int, default=100)
+     args = parser.parse_args()
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(exist_ok=True, parents=True)
+     metrics = main(args)
+     with open(output_dir / "time_token_metrics.json", "w") as f:
+         json.dump(metrics, f)
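For intuition, here is a minimal sketch of the encoding this script tests: each simulated row (x1, x2, t1, t2, y) becomes a three-token sequence in which the elapsed time between the two concepts is itself a vocabulary token. The row values below are hypothetical; the string format follows create_time_token_tokenizer:

row = (3, 7, 10, 24, 1)  # hypothetical (x1, x2, t1, t2, y)
x1, x2, t1, t2, y = row
tokens = [f"c-{x1}", f"t-{t2 - t1}", f"c-{x2}"]  # ["c-3", "t-14", "c-7"]
# The tokenizer maps each string to an integer id, so every input is
# exactly three ids long, matching max_position_embeddings=3 above.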
cehrgpt/tools/generate_causal_patient_split_by_age.py
@@ -0,0 +1,146 @@
+ import numpy as np
+ import pandas as pd
+
+ # Define race mapping
+ race_mapping = {
+     "38003613": "8557",
+     "38003610": "8557",
+     "38003579": "8515",
+     "44814653": "0",
+ }
+
+ # Invalid age groups
+ invalid_age_groups = [
+     "age:100-110",
+     "age:110-120",
+     "age:120-130",
+     "age:130-140",
+     "age:140-150",
+     "age:150-160",
+     "age:160-170",
+     "age:170-180",
+     "age:180-190",
+     "age:190-200",
+     "age:640-650",
+     "age:680-690",
+     "age:730-740",
+     "age:740-750",
+     "age:890-900",
+     "age:900-910",
+     "age:-10-0",
+ ]
+
+
+ def age_group_func(age_str):
+     """
+     Categorize an age into a 10-year age group.
+
+     Args:
+         age_str (str): A string containing the age in the format "age:XX".
+
+     Returns:
+         str: A string representing the 10-year age group "age:XX-XX".
+     """
+     age = int(age_str.split(":")[1])
+     group_number = age // 10
+     return f"age:{group_number * 10}-{(group_number + 1) * 10}"
+
+
+ def map_race(race):
+     return race_mapping.get(race, race)
+
+
+ def main(args):
+     # Load data
+     patient_sequence = pd.read_parquet(args.patient_sequence)
+     # Extract and preprocess demographics
+     demographics = patient_sequence.concept_ids.apply(
+         lambda concept_ids: concept_ids[:4]
+     )
+     patient_sequence["demographics"] = demographics
+     year = demographics.apply(lambda concepts: concepts[0])
+     age = demographics.apply(lambda concepts: concepts[1]).apply(age_group_func)
+     gender = demographics.apply(lambda concepts: concepts[2])
+     race = demographics.apply(lambda concepts: concepts[3])
+     death = patient_sequence.concept_ids.apply(
+         lambda concept_ids: int(concept_ids[-2] == "[DEATH]")
+     )
+
+     patient_sequence["year"] = year
+     patient_sequence["age"] = age
+     patient_sequence["gender"] = gender
+     patient_sequence["race"] = race
+     patient_sequence["death"] = death
+
+     demographics = patient_sequence[
+         ["person_id", "death", "year", "age", "gender", "race", "split"]
+     ].copy()
+     demographics["race"] = demographics.race.apply(map_race)
+
+     demographics_clean = demographics[
+         (demographics.gender != "0") & (~demographics.age.isin(invalid_age_groups))
+     ]
+     patient_sequence_clean = patient_sequence[
+         patient_sequence.person_id.isin(demographics_clean.person_id)
+     ]
+
+     # Calculate probabilities
+     probs = (
+         demographics_clean.groupby(["age"])["person_id"].count()
+         / len(demographics_clean)
+     ).reset_index()
+     probs.rename(columns={"person_id": "prob"}, inplace=True)
+
+     # Adjust probabilities: weight the ten remaining age groups linearly
+     # from 10 (age:0-10) down to 1 (age:90-100), so younger patients are
+     # favored when sampling the training split
+     np.random.seed(42)
+     x = np.asarray(list(reversed(range(1, 11))))
+     adjusted_probs = probs.prob * pd.Series(x)
+     adjusted_probs = adjusted_probs / adjusted_probs.sum()
+     probs["adjusted_prob"] = adjusted_probs
+
+     demographics_for_sampling = patient_sequence_clean[
+         ["year", "age", "race", "gender", "person_id"]
+     ].merge(probs, on="age")
+     demographics_for_sampling["adjusted_prob"] = (
+         demographics_for_sampling.adjusted_prob
+         / demographics_for_sampling.adjusted_prob.sum()
+     )
+
+     # Train/Validation Split
+     causal_train_split = demographics_for_sampling.sample(
+         args.num_patients, replace=False, weights="adjusted_prob", random_state=1
+     )
+     causal_train_split["split"] = "train"
+     causal_val_split = demographics_for_sampling[
+         ~demographics_for_sampling.person_id.isin(causal_train_split.person_id)
+     ].copy()
+     causal_val_split["split"] = "validation"
+
+     causal_train_val_split = pd.concat([causal_train_split, causal_val_split])
+
+     # Save outputs
+     causal_train_val_split.to_parquet(args.output_folder, index=False)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Arguments for a causal patient split by age groups"
+     )
+     parser.add_argument(
+         "--patient_sequence",
+         required=True,
+     )
+     parser.add_argument(
+         "--num_patients",
+         default=1_000_000,
+         type=int,
+         required=False,
+     )
+     parser.add_argument(
+         "--output_folder",
+         required=True,
+     )
+     # Call the main function with parsed arguments
+     main(parser.parse_args())
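As a usage sketch for the new split script (paths below are hypothetical; the input parquet is assumed to carry person_id, split, and a concept_ids column whose first four entries are year, age, gender, and race tokens, matching the slicing in main):

from argparse import Namespace

from cehrgpt.tools.generate_causal_patient_split_by_age import main

# Hypothetical paths. Despite the flag name, output_folder is passed
# directly to DataFrame.to_parquet, so a single .parquet file path works.
main(
    Namespace(
        patient_sequence="/data/patient_sequence.parquet",
        num_patients=1_000_000,
        output_folder="/data/causal_train_val_split.parquet",
    )
)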