cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.optim as optim
6
+ from sklearn.metrics import accuracy_score, roc_auc_score
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers import BertConfig, BertModel
9
+
10
+ from cehrgpt.simulations.time_embedding_simulation import generate_simulation_data
11
+
12
+
13
class ModelTimeToken(torch.nn.Module):
    """Tiny BERT classifier over fixed-length token sequences.

    Token ids are embedded with a local embedding table and fed to a small
    BERT encoder via ``inputs_embeds``. The per-position hidden states are
    flattened into one feature vector per example, and a linear layer maps it
    to two class logits.

    Args:
        vocab_size: Number of rows in the embedding table; ids must be
            ``< vocab_size``.
        hidden_size: Width of the embeddings and BERT hidden states. BERT is
            configured with 2 attention heads, so this must be divisible by 2.
        seq_len: Fixed input sequence length. It is also used as BERT's
            ``max_position_embeddings``, so inputs may not be longer.
    """

    def __init__(self, vocab_size: int, hidden_size: int = 16, seq_len: int = 3):
        super().__init__()
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.bert = BertModel(
            BertConfig(
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                num_attention_heads=2,
                num_hidden_layers=2,
                intermediate_size=2 * hidden_size,
                hidden_dropout_prob=0.0,
                attention_probs_dropout_prob=0.0,
                max_position_embeddings=seq_len,
            ),
            add_pooling_layer=False,
        )
        # Flattened (seq_len * hidden_size) features -> 2 class logits.
        self.linear = torch.nn.Linear(seq_len * hidden_size, 2)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.FloatTensor], torch.FloatTensor]:
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``. ``seq_len`` must
                equal ``self.seq_len`` for the reshape below.
            labels: Optional class indices of shape ``(batch,)``.

        Returns:
            ``(loss, logits)``. ``loss`` is ``None`` when ``labels`` is ``None``;
            ``logits`` has shape ``(batch, 2)``.
        """
        batch_size = input_ids.shape[0]
        embeddings = self.embedding(input_ids)
        # Call the module itself (not .forward) so nn.Module hooks still run.
        bert_output = self.bert(inputs_embeds=embeddings, return_dict=True)
        features = bert_output.last_hidden_state.reshape(
            (batch_size, self.seq_len * self.hidden_size)
        )
        logits = self.linear(features)
        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits, labels)
        return loss, logits
47
+
48
+
49
def create_time_token_tokenizer(simulated_data):
    """Build a token -> id vocabulary for the time-token simulation.

    Each row of ``simulated_data`` is unpacked as ``(x1, x2, t1, t2, y)``, and
    three tokens are derived from it, in this order:

    - ``"c-{x1}"``
    - ``"c-{x2}"``
    - the time-gap token ``"t-{t2 - t1}"``

    Tokens are numbered from 1 in order of first appearance. Id 0 is never
    assigned, presumably to reserve it for padding.

    Args:
        simulated_data: Iterable of 5-element rows.

    Returns:
        Dict mapping token strings to integer ids starting at 1.
    """
    # A dict preserves insertion order and gives O(1) membership checks,
    # avoiding a linear list scan for every token.
    vocab = {}
    for x1, x2, t1, t2, _ in simulated_data:
        for token in (f"c-{x1}", f"c-{x2}", f"t-{t2 - t1}"):
            # len(vocab) + 1 is evaluated before insertion, so a new token
            # receives the next consecutive id; existing tokens keep theirs.
            vocab.setdefault(token, len(vocab) + 1)
    return vocab
63
+
64
+
65
# Use the default CUDA device when one is visible; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
66
+
67
+
68
def eval_step(simulated_data, time_token_tokenizer, time_embedding_model):
    """Evaluate the model on every row of ``simulated_data`` in one batch.

    Each row ``(x1, x2, t1, t2, y)`` is encoded as the 3-token sequence
    ``[c-x1, t-(t2 - t1), c-x2]`` using ``time_token_tokenizer``. ``y`` is the
    binary label.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning, so a following training
    step is not left in eval mode.

    Returns:
        ``(accuracy, roc_auc)`` for the positive-class probability. Both
        values are also printed.
    """
    was_training = time_embedding_model.training
    time_embedding_model.eval()
    eval_input_ids = []
    eval_y = []
    for x1, x2, t1, t2, y in simulated_data:
        eval_input_ids.append(
            [
                time_token_tokenizer[f"c-{x1}"],
                time_token_tokenizer[f"t-{t2 - t1}"],
                time_token_tokenizer[f"c-{x2}"],
            ]
        )
        eval_y.append(y)
    try:
        with torch.no_grad():
            batched_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(
                device
            )
            # Forward pass only: no labels are passed, so no loss is computed.
            _, logits = time_embedding_model(batched_input_ids)
            y_probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
    finally:
        time_embedding_model.train(was_training)
    batched_y = np.asarray(eval_y)
    roc_auc = roc_auc_score(batched_y, y_probs[:, 1])
    # Predict class 1 when its probability exceeds that of class 0.
    accuracy = accuracy_score(batched_y, y_probs[:, 1] > y_probs[:, 0])
    print(f"ROC AUC: {roc_auc}")
    print(f"Accuracy: {accuracy}")
    return accuracy, roc_auc
97
+
98
+
99
def train_step(
    simulated_data,
    time_token_tokenizer,
    time_embedding_model,
    time_embedding_optimizer,
    batch_size: int = 8,
):
    """Run one optimizer update on a randomly sampled mini-batch.

    Samples distinct rows from ``simulated_data``, which is indexed as a 2-D
    array (``simulated_data[indices, :]``). The batch holds ``batch_size``
    rows, capped at the number of available rows. Each row
    ``(x1, x2, t1, t2, y)`` is encoded as ``[c-x1, t-(t2 - t1), c-x2]``, and
    one step of ``time_embedding_optimizer`` is taken on the model's
    cross-entropy loss.

    Args:
        simulated_data: 2-D array of rows ``(x1, x2, t1, t2, y)``.
        time_token_tokenizer: Mapping from token string to id.
        time_embedding_model: Model returning ``(loss, logits)`` when called
            with ``(input_ids, labels)``.
        time_embedding_optimizer: Optimizer over the model's parameters.
        batch_size: Rows per mini-batch. Defaults to 8.

    Returns:
        The loss tensor of this step.
    """
    # An evaluation pass may have switched the model to eval mode.
    time_embedding_model.train()
    n_rows = simulated_data.shape[0]
    # Sampling without replacement cannot draw more rows than exist.
    indices = np.random.choice(n_rows, size=min(batch_size, n_rows), replace=False)
    batched_input_ids = []
    batched_y = []
    for x1, x2, t1, t2, y in simulated_data[indices, :]:
        batched_input_ids.append(
            [
                time_token_tokenizer[f"c-{x1}"],
                time_token_tokenizer[f"t-{t2 - t1}"],
                time_token_tokenizer[f"c-{x2}"],
            ]
        )
        batched_y.append(y)
    input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
    labels = torch.tensor(batched_y, dtype=torch.long).to(device)
    time_embedding_optimizer.zero_grad()
    loss, _ = time_embedding_model(input_ids, labels)
    loss.backward()
    time_embedding_optimizer.step()
    return loss


def main(args):
    """Train ``ModelTimeToken`` on simulated data and periodically evaluate it.

    Reads these fields from ``args``:

    - ``n_samples``, ``n_steps`` and ``eval_frequency``.
    - ``batch_size``, optional; defaults to 8 when absent.

    Evaluation uses the same simulated data as training. It runs on every
    step that is a positive multiple of ``eval_frequency``; an
    ``eval_frequency`` of 0 or less disables it.

    Returns:
        Dict of parallel lists ``"steps"``, ``"roc_auc"`` and ``"accuracy"``.
    """
    simulated_data = generate_simulation_data(args.n_samples)
    time_token_tokenizer = create_time_token_tokenizer(simulated_data)
    # +1: token ids start at 1, so the embedding table needs max_id + 1 rows.
    time_embedding_model = ModelTimeToken(len(time_token_tokenizer) + 1).to(device)
    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
    batch_size = getattr(args, "batch_size", 8)
    steps = []
    roc_aucs = []
    accuracies = []
    for step in range(args.n_steps):
        loss = train_step(
            simulated_data,
            time_token_tokenizer,
            time_embedding_model,
            time_embedding_optimizer,
            batch_size=batch_size,
        )
        print(f"Step {step}: Loss = {loss.item()}")
        # Evaluate based on the current step, every eval_frequency steps.
        if (
            args.eval_frequency > 0
            and step > 0
            and step % args.eval_frequency == 0
        ):
            accuracy, roc_auc = eval_step(
                simulated_data, time_token_tokenizer, time_embedding_model
            )
            steps.append(step)
            roc_aucs.append(roc_auc)
            accuracies.append(accuracy)
    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
159
+
160
+
161
if __name__ == "__main__":
    # Script entry point: parse CLI options, run the simulation and write the
    # collected metrics as JSON into --output_dir.
    import argparse
    import json
    from pathlib import Path

    cli = argparse.ArgumentParser("Model with time token simulation")
    cli.add_argument("--output_dir", type=str, required=True)
    for flag, default in (
        ("--n_steps", 10000),
        ("--n_samples", 1000),
        ("--batch_size", 128),
        ("--eval_frequency", 100),
    ):
        cli.add_argument(flag, type=int, default=default)
    args = cli.parse_args()
    metrics_dir = Path(args.output_dir)
    metrics_dir.mkdir(parents=True, exist_ok=True)
    metrics = main(args)
    (metrics_dir / "time_token_metrics.json").write_text(json.dumps(metrics))
@@ -0,0 +1,23 @@
1
+ task_name: "cabg_prediction"
2
+ outcome_events: [
3
+ "43528001",
4
+ "43528003",
5
+ "43528004",
6
+ "43528002",
7
+ "4305852",
8
+ "4168831",
9
+ "2107250",
10
+ "2107216",
11
+ "2107222",
12
+ "2107231",
13
+ "4336464",
14
+ "4231998",
15
+ "4284104",
16
+ "2100873",
17
+ ]
18
+ future_visit_start: 0
19
+ future_visit_end: -1
20
+ prediction_window_start: 0
21
+ prediction_window_end: 365
22
+ max_new_tokens: 1024
23
+ include_descendants: true
@@ -80,20 +80,9 @@ class TimeToEventModel:
80
80
  return token in self.outcome_events
81
81
 
82
82
  def simulate(
83
- self, partial_history: Union[np.ndarray, List[str]]
83
+ self,
84
+ partial_history: Union[np.ndarray, List[str]],
84
85
  ) -> List[List[str]]:
85
-
86
- sequence_is_demographics = len(partial_history) == 4 and partial_history[
87
- 0
88
- ].startswith("year")
89
- sequence_ends_ve = is_visit_end(partial_history[-1])
90
-
91
- if not (sequence_is_demographics | sequence_ends_ve):
92
- raise ValueError(
93
- "There are only two types of sequences allowed. 1) the sequence only contains "
94
- "demographics; 2) the sequence ends on VE;"
95
- )
96
-
97
86
  token_ids = self.tokenizer.encode(partial_history)
98
87
  prompt = torch.tensor(token_ids).unsqueeze(0).to(self.device)
99
88
 
@@ -118,9 +118,9 @@ def main(args):
118
118
  LOG.info(f"Top P {args.top_p}")
119
119
  LOG.info(f"Top K {args.top_k}")
120
120
 
121
- cehrgpt_model.resize_position_embeddings(
122
- cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
123
- )
121
+ # cehrgpt_model.resize_position_embeddings(
122
+ # cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
123
+ # )
124
124
 
125
125
  generation_config = TimeToEventModel.get_generation_config(
126
126
  tokenizer=cehrgpt_tokenizer,
@@ -190,14 +190,22 @@ def main(args):
190
190
  args.max_n_trial,
191
191
  )
192
192
  visit_counter = sum([int(is_visit_end(_)) for _ in partial_history])
193
+ predicted_boolean_probability = (
194
+ sum([event != "0" for event in concept_time_to_event.outcome_events])
195
+ / len(concept_time_to_event.outcome_events)
196
+ if concept_time_to_event
197
+ else 0.0
198
+ )
193
199
  tte_outputs.append(
194
200
  {
195
- "person_id": record["person_id"],
196
- "index_date": record["index_date"],
201
+ "subject_id": record["person_id"],
202
+ "prediction_time": record["index_date"],
197
203
  "visit_counter": visit_counter,
198
- "label": label,
204
+ "boolean_value": label,
205
+ "predicted_boolean_probability": predicted_boolean_probability,
206
+ "predicted_boolean_value": None,
199
207
  "time_to_event": time_to_event,
200
- "prediction": (
208
+ "trials": (
201
209
  asdict(concept_time_to_event) if concept_time_to_event else None
202
210
  ),
203
211
  }
@@ -263,9 +271,13 @@ def filter_out_existing_results(
263
271
  parquet_files = glob.glob(os.path.join(prediction_output_folder_name, "*parquet"))
264
272
  if parquet_files:
265
273
  cohort_members = set()
266
- results_dataframe = pd.read_parquet(parquet_files)[["person_id", "index_date"]]
274
+ results_dataframe = pd.read_parquet(parquet_files)[
275
+ ["subject_id", "prediction_time"]
276
+ ]
267
277
  for row in results_dataframe.itertuples():
268
- cohort_members.add((row.person_id, row.index_date.strftime("%Y-%m-%d")))
278
+ cohort_members.add(
279
+ (row.subject_id, row.prediction_time.strftime("%Y-%m-%d"))
280
+ )
269
281
 
270
282
  def filter_func(batched):
271
283
  return [
@@ -292,12 +304,14 @@ def flush_to_disk_if_full(
292
304
  pd.DataFrame(
293
305
  tte_outputs,
294
306
  columns=[
295
- "person_id",
296
- "index_date",
307
+ "subject_id",
308
+ "prediction_time",
297
309
  "visit_counter",
298
- "label",
310
+ "boolean_value",
311
+ "predicted_boolean_probability",
312
+ "predicted_boolean_value",
299
313
  "time_to_event",
300
- "prediction",
314
+ "trials",
301
315
  ],
302
316
  ).to_parquet(output_parquet_file)
303
317
  tte_outputs.clear()
File without changes