cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,223 @@
1
+ from functools import partial
2
from typing import Callable, Tuple, Union
3
+
4
+ import optuna
5
+ from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
6
+ from datasets import Dataset, DatasetDict
7
+ from transformers import (
8
+ EarlyStoppingCallback,
9
+ Trainer,
10
+ TrainerCallback,
11
+ TrainingArguments,
12
+ )
13
+ from transformers.utils import logging
14
+
15
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
16
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
17
+
18
# Module logger: the logger named "transformers" obtained via transformers.utils.logging.
LOG = logging.get_logger("transformers")
19
+
20
+
21
class OptunaMetricCallback(TrainerCallback):
    """
    Record the best evaluation loss seen so far as ``metrics["optuna_best_metric"]``.

    On every evaluation this callback writes the running minimum of ``state.best_metric`` and the
    current ``eval_loss`` into the evaluation metrics dict. Using that key as the Optuna objective
    means a trial is scored by the best loss it reached, even if early stopping ends training
    after the loss has started to rise.

    NOTE(review): taking ``min`` assumes ``state.best_metric`` tracks the same quantity as
    ``eval_loss`` (i.e. ``metric_for_best_model`` is the loss) and that lower is better —
    confirm against the TrainingArguments used with this callback.

    Example Usage:
        ```
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=[OptunaMetricCallback()],
        )
        trainer.hyperparameter_search(
            direction="minimize",
            backend="optuna",
            compute_objective=lambda m: m["optuna_best_metric"],
        )
        ```
    """

    def on_evaluate(self, args, state, control, **kwargs):
        """
        Add ``optuna_best_metric`` to the evaluation metrics.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``state.best_metric`` is read if it is set.
            control: Trainer control object (unused).
            **kwargs: Must include ``metrics``, the evaluation metrics dict, which must contain
                ``eval_loss``. The dict is updated in place.

        Updates:
            ``metrics["optuna_best_metric"]``: ``min(state.best_metric, metrics["eval_loss"])``
            when a best metric exists, otherwise ``metrics["eval_loss"]``.

        Raises:
            KeyError: If the metrics dict has no ``eval_loss`` entry.
        """
        metrics = kwargs.get("metrics", {})
        current_loss = metrics["eval_loss"]
        if state.best_metric is None:
            # First evaluation (or no best tracked yet): the current loss is the best so far.
            best_loss = current_loss
        else:
            best_loss = min(state.best_metric, current_loss)
        metrics["optuna_best_metric"] = best_loss
70
+
71
+
72
# Define the hyperparameter search space with parameters
def hp_space(
    trial: "optuna.Trial",
    lr_range: Tuple[float, float] = (1e-5, 5e-5),
    batch_sizes=None,
    weight_decays: Tuple[float, float] = (1e-4, 1e-2),
    num_train_epochs: Union[int, Tuple[int, ...]] = 10,
):
    """
    Build the Optuna search space consumed by ``Trainer.hyperparameter_search``.

    Args:
        trial: The Optuna trial used to sample each value.
        lr_range: ``(low, high)`` bounds for ``learning_rate``, sampled on a log scale.
        batch_sizes: Candidate values for ``per_device_train_batch_size``; ``None`` means
            ``[4, 8]``.
        weight_decays: ``(low, high)`` bounds for ``weight_decay``, sampled on a log scale.
        num_train_epochs: Either a single epoch count, used as a fixed value, or a sequence
            unpacked into ``trial.suggest_int`` (``(low, high)`` or ``(low, high, step)``).
            A one-element sequence is treated like a single value.

    Returns:
        A dict mapping ``TrainingArguments`` attribute names to the sampled values.
    """
    if batch_sizes is None:
        batch_sizes = [4, 8]

    # A scalar (including the default of 10) cannot be star-unpacked into suggest_int,
    # so it becomes a degenerate low == high range that always yields that value.
    if isinstance(num_train_epochs, int):
        epoch_args = (num_train_epochs, num_train_epochs)
    else:
        epoch_args = tuple(num_train_epochs)
        if len(epoch_args) == 1:
            epoch_args = (epoch_args[0], epoch_args[0])

    return {
        "learning_rate": trial.suggest_float("learning_rate", *lr_range, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", batch_sizes
        ),
        "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", *epoch_args),
    }
90
+
91
+
92
def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
    """
    Return a random subset holding roughly ``percentage`` of ``data``.

    The subset is the "test" side of ``data.train_test_split(test_size=percentage, seed=seed)``,
    so the same seed always selects the same rows. This is handy for shrinking datasets for
    hyperparameter tuning or quick experiments.

    Args:
        data (Dataset): Dataset to sample from.
        percentage (float): Fraction of rows to keep, as a decimal (e.g. 0.1 for 10%). A value
            of exactly 1.0 returns ``data`` itself without splitting; other values are passed to
            ``train_test_split`` as-is, so keep them between 0 and 1.
        seed (int): Seed forwarded to ``train_test_split`` for reproducibility.

    Returns:
        Dataset: The sampled subset (or ``data`` unchanged when ``percentage == 1.0``).

    Example:
        ```
        sampled_data = sample_dataset(my_dataset, percentage=0.1, seed=42)
        ```
    """
    if percentage != 1.0:
        splits = data.train_test_split(test_size=percentage, seed=seed)
        return splits["test"]
    # Keeping everything: no split needed.
    return data
126
+
127
+
128
def perform_hyperparameter_search(
    model_init: Callable,
    dataset: DatasetDict,
    data_collator: CehrGptDataCollator,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    cehrgpt_args: CehrGPTArguments,
) -> TrainingArguments:
    """
    Tune CehrGPT training hyperparameters with Optuna through the Hugging Face Trainer.

    When ``cehrgpt_args.hyperparameter_tuning`` is enabled, a Trainer is built on random subsets
    (``cehrgpt_args.hyperparameter_tuning_percentage``, via ``sample_dataset``) of the "train" and
    "validation" splits. ``Trainer.hyperparameter_search`` then runs ``cehrgpt_args.n_trials``
    Optuna trials over the space defined by ``hp_space``:
      - learning rate
      - per-device train batch size
      - weight decay
      - number of training epochs

    Each trial is scored by ``optuna_best_metric``, the best ``eval_loss`` recorded by
    ``OptunaMetricCallback``, and the search minimises it.

    Args:
        model_init (Callable): Builds a fresh model for every trial.
        dataset (DatasetDict): Must contain "train" and "validation" splits.
        data_collator (CehrGptDataCollator): Collator used to build batches.
        training_args (TrainingArguments): Base training configuration. The same object is
            handed to the tuning Trainer and, after the search, receives the best values.
        model_args (ModelArguments): Supplies ``early_stopping_patience``.
        cehrgpt_args (CehrGPTArguments): Tuning options used here:
            - ``hyperparameter_tuning``
            - ``hyperparameter_tuning_percentage``
            - ``lr_low`` / ``lr_high``
            - ``weight_decays_low`` / ``weight_decays_high``
            - ``hyperparameter_batch_sizes``
            - ``hyperparameter_num_train_epochs``
            - ``n_trials``

    Returns:
        TrainingArguments: ``training_args`` itself. When tuning is enabled, it is updated in
        place with the best trial's hyperparameters. Otherwise it is returned unchanged.

    Example:
        ```
        best_training_args = perform_hyperparameter_search(
            model_init=my_model_init,
            dataset=my_dataset_dict,
            data_collator=my_data_collator,
            training_args=initial_training_args,
            model_args=model_args,
            cehrgpt_args=cehrgpt_args
        )
        ```

    Notes:
        - ``EarlyStoppingCallback`` is always attached to the tuning Trainer, with patience
          ``model_args.early_stopping_patience``.

    Logging:
        Logs the best hyperparameters found at the end of the search.
    """
    if not cehrgpt_args.hyperparameter_tuning:
        return training_args

    sampled_train = sample_dataset(
        dataset["train"],
        cehrgpt_args.hyperparameter_tuning_percentage,
        training_args.seed,
    )
    sampled_val = sample_dataset(
        dataset["validation"],
        cehrgpt_args.hyperparameter_tuning_percentage,
        training_args.seed,
    )
    hyperparam_trainer = Trainer(
        model_init=model_init,
        data_collator=data_collator,
        train_dataset=sampled_train,
        eval_dataset=sampled_val,
        callbacks=[
            EarlyStoppingCallback(model_args.early_stopping_patience),
            OptunaMetricCallback(),
        ],
        args=training_args,
    )
    search_space = partial(
        hp_space,
        lr_range=(cehrgpt_args.lr_low, cehrgpt_args.lr_high),
        weight_decays=(
            cehrgpt_args.weight_decays_low,
            cehrgpt_args.weight_decays_high,
        ),
        batch_sizes=cehrgpt_args.hyperparameter_batch_sizes,
        num_train_epochs=cehrgpt_args.hyperparameter_num_train_epochs,
    )
    best_trial = hyperparam_trainer.hyperparameter_search(
        direction="minimize",
        hp_space=search_space,
        backend="optuna",
        n_trials=cehrgpt_args.n_trials,
        # Score trials by the best loss reached (written by OptunaMetricCallback),
        # not by the loss at the final evaluation.
        compute_objective=lambda metrics: metrics["optuna_best_metric"],
    )
    LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
    # Copy the winning trial's values onto the caller's TrainingArguments.
    for name, value in best_trial.hyperparameters.items():
        setattr(training_args, name, value)
    return training_args
File without changes
@@ -0,0 +1,8 @@
1
+ task_name: "30_day_readmission_prediction"
2
+ outcome_events: ["9201", "262", "8971", "8920"]
3
+ include_descendants: false
4
+ future_visit_start: 0
5
+ future_visit_end: -1
6
+ prediction_window_start: 0
7
+ prediction_window_end: 30
8
+ max_new_tokens: 128
@@ -0,0 +1,8 @@
1
+ task_name: "next_visit_type_prediction"
2
+ outcome_events: [
3
+ '9202', '9203', '581477', '9201', '5083', '262', '38004250', '8883', '38004238', '38004251',
4
+ '38004222', '38004268', '38004228', '32693', '8971', '38004269', '38004193', '32036', '8782'
5
+ ]
6
+ include_descendants: false
7
+ future_visit_start: 0
8
+ future_visit_end: 1
@@ -0,0 +1,8 @@
1
+ task_name: "t2dm_hf_prediction"
2
+ outcome_events: ["316139"]
3
+ future_visit_start: 0
4
+ future_visit_end: -1
5
+ prediction_window_start: 30
6
+ prediction_window_end: -1
7
+ max_new_tokens: 512
8
+ include_descendants: true
@@ -0,0 +1,226 @@
1
+ import math
2
+ from collections import Counter
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from cehrbert_data.decorators.patient_event_decorator_base import time_month_token
9
+ from transformers import GenerationConfig
10
+
11
+ from cehrgpt.gpt_utils import (
12
+ extract_time_interval_in_days,
13
+ is_att_token,
14
+ is_visit_end,
15
+ is_visit_start,
16
+ )
17
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
18
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
19
+
20
+
21
@dataclass
class TimeToEvent:
    """
    Summary of simulated time-to-event outcomes for one patient history.

    Instances are built by ``create_time_to_event`` from the ``(outcome_event, time_delta)``
    pairs collected in ``TimeToEventModel.predict_time_to_events``. There, ``time_delta`` is
    accumulated from ``extract_time_interval_in_days``, so time values are presumably in days.
    """

    # Mean of ``time_intervals`` (np.mean).
    average_time: float
    # Median of ``time_intervals`` (np.median).
    median_time: float
    # Standard deviation of ``time_intervals`` (np.std with its default ddof=0).
    standard_deviation: float
    # Most frequent time bucket label, as produced by ``time_month_token``.
    most_likely_time: str
    # Number of simulations summarised.
    num_of_simulations: int
    # One time interval per simulation: time until the outcome (or until the window closed).
    time_intervals: List[int]
    # One outcome token per simulation; "0" marks that no outcome event was observed.
    outcome_events: List[str]
    # Rows of {"time_interval": bucket, "probability": p}, sorted by descending probability.
    time_interval_probability_table: List[Dict[str, Any]]
31
+
32
+
33
def create_time_to_event(
    time_event_tuples: List[Tuple[str, int]], num_of_simulations: int
) -> TimeToEvent:
    """
    Summarise simulated ``(outcome_event, time_interval)`` pairs into a ``TimeToEvent``.

    Each interval is mapped to a bucket label with ``time_month_token``. The bucket frequencies
    give the most likely bucket and a probability table sorted by descending probability; ties
    keep first-seen order.

    Args:
        time_event_tuples: One ``(outcome_event, time_interval)`` pair per simulation.
        num_of_simulations: Value stored in ``TimeToEvent.num_of_simulations``.

    Returns:
        TimeToEvent: Mean, median and standard deviation of the intervals, the most likely
        bucket, the probability table, and the raw events/intervals as lists.

    Raises:
        ValueError: If ``time_event_tuples`` is empty.
    """
    if not time_event_tuples:
        raise ValueError(
            "time_event_tuples must contain at least one (outcome_event, time_interval) pair"
        )

    outcome_events = [event for event, _ in time_event_tuples]
    time_intervals = [interval for _, interval in time_event_tuples]

    time_bucket_counter = Counter(time_month_token(t) for t in time_intervals)
    total_count = len(time_intervals)
    # most_common() sorts by count descending and keeps insertion order among ties, so its
    # first entry is the mode and the whole list is already in probability order.
    ranked_buckets = time_bucket_counter.most_common()
    most_common_item = ranked_buckets[0][0]
    sorted_probability_table = [
        {"time_interval": bucket, "probability": count / total_count}
        for bucket, count in ranked_buckets
    ]
    return TimeToEvent(
        time_intervals=time_intervals,
        outcome_events=outcome_events,
        average_time=np.mean(time_intervals),
        median_time=np.median(time_intervals),
        standard_deviation=np.std(time_intervals),
        most_likely_time=most_common_item,
        num_of_simulations=num_of_simulations,
        time_interval_probability_table=sorted_probability_table,
    )
59
+
60
+
61
class TimeToEventModel:
    """
    Estimate time-to-event distributions by sampling patient futures from a CEHR-GPT model.

    A partial patient history is encoded and used as a prompt for ``model.generate``. Each
    decoded continuation is scanned for the first outcome event that falls inside the
    configured visit / day windows. The results are summarised with ``create_time_to_event``.
    """

    def __init__(
        self,
        tokenizer: CehrGptTokenizer,
        model: CEHRGPT2LMHeadModel,
        outcome_events: List[str],
        generation_config: GenerationConfig,
        device: torch.device = torch.device("cpu"),
        batch_size: int = 32,
    ):
        """
        Args:
            tokenizer: Encodes prompts and decodes generated sequences.
            model: Generative model; switched to eval mode here.
            outcome_events: Tokens that count as the outcome of interest.
            generation_config: Sampling configuration. Its ``num_return_sequences`` sets how
                many continuations are simulated. The field is modified temporarily during a
                call and restored afterwards.
            device: Device the prompt tensor is moved to.
            batch_size: Maximum number of sequences requested from ``model.generate`` at once.
        """
        self.tokenizer = tokenizer
        self.model = model.eval()
        self.generation_config = generation_config
        self.outcome_events = outcome_events
        self.device = device
        self.batch_size = batch_size
        # Taken from model.config.n_positions (presumably the model's context length).
        self.max_sequence = model.config.n_positions

    def is_outcome_event(self, token: str):
        """Return True if ``token`` is one of the configured outcome events."""
        return token in self.outcome_events

    def simulate(
        self, partial_history: Union[np.ndarray, List[str]]
    ) -> List[List[str]]:
        """
        Generate ``generation_config.num_return_sequences`` continuations of ``partial_history``.

        Sequences are requested in chunks of at most ``batch_size``. The final chunk only asks
        for the remainder, so exactly ``num_return_sequences`` sequences are produced.

        Args:
            partial_history: Token strings of the patient history. It must either consist of
                exactly four demographic tokens (the first starting with ``"year"``) or end
                with a visit-end token.

        Returns:
            The decoded generated sequences. Callers slice off the first
            ``len(partial_history)`` tokens, i.e. each sequence is expected to start with the
            prompt.

        Raises:
            ValueError: If ``partial_history`` is neither demographics-only nor ends with a
                visit-end token.
        """
        sequence_is_demographics = len(partial_history) == 4 and partial_history[
            0
        ].startswith("year")
        # Guard the [-1] lookup so an empty history reports the ValueError below.
        sequence_ends_ve = len(partial_history) > 0 and is_visit_end(
            partial_history[-1]
        )

        if not (sequence_is_demographics or sequence_ends_ve):
            raise ValueError(
                "There are only two types of sequences allowed. 1) the sequence only contains "
                "demographics; 2) the sequence ends on VE;"
            )

        token_ids = self.tokenizer.encode(partial_history)
        prompt = torch.tensor(token_ids).unsqueeze(0).to(self.device)

        total_sequences = self.generation_config.num_return_sequences
        simulated_sequences = []
        try:
            with torch.no_grad():
                for start in range(0, total_sequences, self.batch_size):
                    # The last chunk requests only the remaining sequences.
                    self.generation_config.num_return_sequences = min(
                        self.batch_size, total_sequences - start
                    )
                    results = self.model.generate(
                        inputs=prompt,
                        generation_config=self.generation_config,
                    )
                    # Clear the cache
                    torch.cuda.empty_cache()
                    simulated_sequences.extend(
                        self.tokenizer.decode(seq.cpu().numpy())
                        for seq in results.sequences
                    )
        finally:
            # Restore the caller's setting even if generation fails.
            self.generation_config.num_return_sequences = total_sequences
        return simulated_sequences

    def _extract_time_event(
        self,
        future_tokens,
        future_visit_start: int,
        future_visit_end: int,
        prediction_window_start: int,
        prediction_window_end: int,
    ) -> Optional[Tuple[str, int]]:
        """
        Convert one generated continuation into an ``(outcome_event, time_delta)`` pair.

        Returns:
            - ``(token, time_delta)`` for the first outcome event inside the window;
            - ``("0", time_delta)`` when a bound (not -1) is crossed first, or when no bounds
              are active and the sequence ends;
            - ``None`` when a bound is active but the sequence ends before the window closes,
              so the outcome cannot be determined.
        """
        visit_bound_active = future_visit_end != -1
        window_bound_active = prediction_window_end != -1
        visit_counter = 0
        time_delta = 0
        for next_token in future_tokens:
            visit_counter += int(is_visit_start(next_token))
            if (visit_bound_active and visit_counter > future_visit_end) or (
                window_bound_active and time_delta > prediction_window_end
            ):
                # Left the prediction window without observing an outcome event.
                return "0", time_delta
            if is_att_token(next_token):
                time_delta += extract_time_interval_in_days(next_token)
            elif (
                visit_counter >= future_visit_start
                and time_delta >= prediction_window_start
                and self.is_outcome_event(next_token)
            ):
                return next_token, time_delta
        if visit_bound_active or window_bound_active:
            # The sequence ended before the window closed; the result is unknown.
            return None
        return "0", time_delta

    def predict_time_to_events(
        self,
        partial_history: Union[np.ndarray, list],
        future_visit_start: int = 0,
        future_visit_end: int = -1,
        prediction_window_start: int = 0,
        prediction_window_end: int = 365,
        debug: bool = False,
        max_n_trial: int = 2,
    ) -> Optional[TimeToEvent]:
        """
        Simulate futures for ``partial_history`` and summarise when an outcome event occurs.

        Each continuation is scanned token by token:
          - Visit-start tokens increment a visit counter.
          - Time-interval (ATT) tokens add their day interval to ``time_delta``.
          - The first outcome event seen once ``visit_counter >= future_visit_start`` and
            ``time_delta >= prediction_window_start`` is recorded as ``(event, time_delta)``.
          - If the visit counter exceeds ``future_visit_end``, or ``time_delta`` exceeds
            ``prediction_window_end``, first, the sequence is recorded as
            ``("0", time_delta)``. -1 disables the corresponding bound.
          - A sequence that ends before either happens is discarded when a bound is active.
            Otherwise it is recorded as ``("0", time_delta)``.

        Discarded sequences are re-simulated, for at most ``max_n_trial`` generation rounds in
        total, until ``generation_config.num_return_sequences`` sequences have been converted.

        Args:
            partial_history: Prompt tokens (see ``simulate`` for the accepted shapes).
            future_visit_start: Minimum visit count before an outcome is accepted.
            future_visit_end: Maximum visit count of the window, or -1 for no limit.
            prediction_window_start: Minimum ``time_delta`` before an outcome is accepted.
            prediction_window_end: Maximum ``time_delta`` of the window, or -1 for no limit.
            debug: Print the discarded continuations.
            max_n_trial: Maximum number of generation rounds.

        Returns:
            A ``TimeToEvent`` built from the converted sequences, or ``None`` if none converted.
        """
        patient_history_length = len(partial_history)
        time_event_tuples: List[Tuple[str, int]] = []
        seqs_failed_to_convert = []
        # The requested total; generation_config.num_return_sequences is mutated below, so the
        # loop must compare against this saved value rather than the live config.
        num_return_sequences = self.generation_config.num_return_sequences
        try:
            n_trial = 0
            while (
                len(time_event_tuples) < num_return_sequences and n_trial < max_n_trial
            ):
                # Only request the sequences still missing from the target count.
                self.generation_config.num_return_sequences = (
                    num_return_sequences - len(time_event_tuples)
                )
                simulated_seqs = self.simulate(partial_history)
                n_trial += 1
                for seq in simulated_seqs:
                    future_tokens = seq[patient_history_length:]
                    time_event = self._extract_time_event(
                        future_tokens,
                        future_visit_start,
                        future_visit_end,
                        prediction_window_start,
                        prediction_window_end,
                    )
                    if time_event is None:
                        seqs_failed_to_convert.append(future_tokens)
                    else:
                        time_event_tuples.append(time_event)
        finally:
            self.generation_config.num_return_sequences = num_return_sequences

        if debug:
            print(f"seqs_failed_to_convert: {seqs_failed_to_convert}")

        if not time_event_tuples:
            return None
        # Count the occurrences of each time tokens for each concept
        return create_time_to_event(time_event_tuples, len(time_event_tuples))

    @staticmethod
    def get_generation_config(
        tokenizer: CehrGptTokenizer,
        max_length: int,
        num_return_sequences: int,
        top_p: float = 1.0,
        top_k: int = 300,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        epsilon_cutoff: float = 0.0,
        max_new_tokens: int = 128,
    ) -> GenerationConfig:
        """
        Build a sampling ``GenerationConfig`` for simulating patient futures.

        Sampling is enabled (``do_sample=True``) with renormalised logits and the KV cache. It
        returns a dict-style output (``return_dict_in_generate=True``) without attentions,
        hidden states or scores.

        NOTE(review): ``bos_token_id`` is set to ``tokenizer.end_token_id`` — confirm this is
        intended rather than a start-token id.
        """
        return GenerationConfig(
            max_length=max_length,
            max_new_tokens=max_new_tokens,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            epsilon_cutoff=epsilon_cutoff,
            top_p=top_p,
            top_k=top_k,
            bos_token_id=tokenizer.end_token_id,
            eos_token_id=tokenizer.end_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            use_cache=True,
            return_dict_in_generate=True,
            output_attentions=False,
            output_hidden_states=False,
            output_scores=False,
            renormalize_logits=True,
        )