cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
`cehrgpt/runners/hyperparameter_search_util.py`

```diff
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Callable, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 import optuna
 from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
```
```diff
@@ -64,28 +64,99 @@ class OptunaMetricCallback(TrainerCallback):
             metrics.update({"optuna_best_metric": metrics["eval_loss"]})
 
 
-[10 removed lines not preserved in this extract]
+def get_suggestion(
+    trial,
+    hyperparameter_name: str,
+    hyperparameters: List[Union[float, int]],
+    is_grid: bool = False,
+) -> Union[float, int]:
+    """
+    Get hyperparameter suggestion based on search mode.
+
+    Args:
+        trial: Optuna trial object
+        hyperparameter_name: Name of the hyperparameter
+        hyperparameters: List of hyperparameter values
+        is_grid: Whether to use grid search mode
+
+    Returns:
+        Suggested hyperparameter value
+
+    Raises:
+        RuntimeError: If Bayesian mode is used with incorrect number of bounds
+    """
+    if is_grid:
+        return trial.suggest_categorical(hyperparameter_name, hyperparameters)
+
+    # For Bayesian optimization, we need exactly 2 values (lower and upper bounds)
+    if len(hyperparameters) != 2:
+        raise RuntimeError(
+            f"{hyperparameter_name} must contain exactly two values (lower and upper bound) "
+            f"for Bayesian Optimization, but {len(hyperparameters)} values were provided: {hyperparameters}"
+        )
+
+    # Ensure bounds are sorted
+    lower, upper = sorted(hyperparameters)
+    return trial.suggest_float(hyperparameter_name, lower, upper, log=True)
+
+
+def hp_space(trial: optuna.Trial, cehrgpt_args: CehrGPTArguments):
+    """
+    Define the hyperparameter search space.
+
+    Args:
+        trial: Optuna trial object
+        cehrgpt_args: CehrGPTArguments
+    Returns:
+        Dictionary of hyperparameter suggestions
+    """
+
+    is_grid = cehrgpt_args.hyperparameter_tuning_is_grid
+    learning_rates = cehrgpt_args.hyperparameter_learning_rates
+    weight_decays = cehrgpt_args.hyperparameter_weight_decays
+    batch_sizes = cehrgpt_args.hyperparameter_batch_sizes
+    num_train_epochs = cehrgpt_args.hyperparameter_num_train_epochs
+
     return {
-        "learning_rate":
+        "learning_rate": get_suggestion(
+            trial, "learning_rate", learning_rates, is_grid
+        ),
         "per_device_train_batch_size": trial.suggest_categorical(
             "per_device_train_batch_size", batch_sizes
         ),
-        "weight_decay": trial
+        "weight_decay": get_suggestion(trial, "weight_decay", weight_decays, is_grid),
         "num_train_epochs": trial.suggest_categorical(
             "num_train_epochs", num_train_epochs
         ),
     }
 
 
+def create_grid_search_space(cehrgpt_args: CehrGPTArguments):
+    """
+    Create the search space dictionary for GridSampler.
+
+    Args:
+        cehrgpt_args: CehrGPTArguments
+
+    Returns:
+        Dictionary defining the grid search space
+    """
+    return {
+        "learning_rate": cehrgpt_args.hyperparameter_learning_rates,
+        "weight_decay": cehrgpt_args.hyperparameter_weight_decays,
+        "per_device_train_batch_size": cehrgpt_args.hyperparameter_batch_sizes,
+        "num_train_epochs": cehrgpt_args.hyperparameter_num_train_epochs,
+    }
+
+
+def calculate_total_combinations(search_space: dict) -> int:
+    """Calculate total number of combinations in grid search."""
+    total = 1
+    for values in search_space.values():
+        total *= len(values)
+    return total
+
+
 def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
     """
     Samples a subset of the given dataset based on a specified percentage.
```
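To make the two suggestion modes concrete, here is a minimal, self-contained Optuna sketch, independent of CehrGPT; the toy objectives and candidate values are illustrative only:

```python
import optuna

# Grid mode: discrete candidates, enumerated exhaustively by GridSampler.
search_space = {"learning_rate": [1e-5, 5e-5, 1e-4], "weight_decay": [0.0, 0.01]}


def grid_objective(trial: optuna.Trial) -> float:
    # Mirrors get_suggestion(..., is_grid=True): a categorical over the listed values.
    lr = trial.suggest_categorical("learning_rate", search_space["learning_rate"])
    wd = trial.suggest_categorical("weight_decay", search_space["weight_decay"])
    return lr + wd  # stand-in for eval_loss


grid_study = optuna.create_study(
    direction="minimize", sampler=optuna.samplers.GridSampler(search_space)
)
grid_study.optimize(grid_objective, n_trials=6)  # 3 x 2 = 6 combinations


# Bayesian (TPE) mode: exactly two values act as log-scaled bounds,
# mirroring trial.suggest_float(name, lower, upper, log=True).
def tpe_objective(trial: optuna.Trial) -> float:
    return trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)


tpe_study = optuna.create_study(
    direction="minimize", sampler=optuna.samplers.TPESampler(seed=42)
)
tpe_study.optimize(tpe_objective, n_trials=10)
```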
```diff
@@ -113,7 +184,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
       returns the "test" portion, which is the specified percentage of the dataset.
     - Ensure that `percentage` is between 0 and 1 to avoid errors.
     """
-    if percentage
+    if percentage >= 1.0:
         return data
 
     return data.train_test_split(
```
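The repaired guard returns the full dataset when `percentage` is 1.0 or above; otherwise `train_test_split` carves out the requested fraction as the "test" split. A small sketch of that behavior with the `datasets` library (toy data; the helper is re-stated here for illustration, not copied from the file):

```python
from datasets import Dataset

data = Dataset.from_dict({"input_ids": [[i] for i in range(100)]})


def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
    # Return the full dataset at 100% or above; otherwise keep the
    # "test" split, i.e. `percentage` of the rows.
    if percentage >= 1.0:
        return data
    return data.train_test_split(test_size=percentage, seed=seed)["test"]


subset = sample_dataset(data, percentage=0.1, seed=42)
print(len(subset))  # 10
```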
```diff
@@ -130,14 +201,13 @@ def perform_hyperparameter_search(
     training_args: TrainingArguments,
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-) -> TrainingArguments:
+) -> Tuple[TrainingArguments, Optional[str]]:
     """
     Perform hyperparameter tuning for the CehrGPT model using Optuna with the Hugging Face Trainer.
 
-    This function
-
-
-    After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.
+    This function supports two modes:
+    1. Bayesian Optimization (TPE): Intelligently explores hyperparameter space using bounds
+    2. Grid Search: Exhaustively tests all combinations of discrete values
 
     Args:
         trainer_class: A Trainer or its subclass
```
```diff
@@ -147,15 +217,15 @@ def perform_hyperparameter_search(
         training_args (TrainingArguments): Configuration for training parameters (e.g., epochs, evaluation strategy).
         model_args (ModelArguments): Model configuration arguments, including early stopping parameters.
         cehrgpt_args (CehrGPTArguments): Additional arguments specific to CehrGPT, including hyperparameter
-            tuning options
+            tuning options and search mode configuration.
 
     Returns:
-        TrainingArguments: Updated
-
+        Tuple[TrainingArguments, Optional[str]]: Updated TrainingArguments with best hyperparameters
+        and optional run_id of the best trial.
 
     Example:
         ```
-        best_training_args = perform_hyperparameter_search(
+        best_training_args, run_id = perform_hyperparameter_search(
             trainer_class=Trainer,
             model_init=my_model_init,
             dataset=my_dataset_dict,
```
```diff
@@ -176,50 +246,91 @@
     Logging:
         Logs the best hyperparameters found at the end of the search.
     """
-    if cehrgpt_args.hyperparameter_tuning:
-[44 removed lines not preserved in this extract]
+    if not cehrgpt_args.hyperparameter_tuning:
+        return training_args, None
+
+    # Prepare hyperparameters based on mode
+    if (
+        cehrgpt_args.hyperparameter_tuning_is_grid
+        and cehrgpt_args.hyperparameter_tuning_is_grid
+    ):
+        search_space = create_grid_search_space(cehrgpt_args)
+        total_combinations = calculate_total_combinations(search_space)
+
+        LOG.info(f"Grid search mode: Testing {total_combinations} combinations")
+        LOG.info(f"Search space: {search_space}")
+
+        # Adjust n_trials for grid search if not set appropriately
+        if cehrgpt_args.n_trials < total_combinations:
+            LOG.warning(
+                f"n_trials ({cehrgpt_args.n_trials}) is less than total combinations ({total_combinations}). "
+                f"Setting n_trials to {total_combinations} to test all combinations."
+            )
+            cehrgpt_args.n_trials = total_combinations
+
+        # Configure sampler based on search mode
+        sampler = optuna.samplers.GridSampler(search_space, seed=training_args.seed)
+    else:
+        LOG.info("Bayesian optimization mode (TPE)")
+        LOG.info(f"Learning rate bounds: {cehrgpt_args.hyperparameter_learning_rates}")
+        LOG.info(f"Weight decay bounds: {cehrgpt_args.hyperparameter_weight_decays}")
+        LOG.info(f"Batch sizes: {cehrgpt_args.hyperparameter_batch_sizes}")
+        LOG.info(f"Epochs: {cehrgpt_args.hyperparameter_num_train_epochs}")
+        # Configure the TPE sampler
+        sampler = optuna.samplers.TPESampler(seed=training_args.seed)
+
+    # Prepare datasets
+    save_total_limit_original = training_args.save_total_limit
+    training_args.save_total_limit = 1
+
+    sampled_train = sample_dataset(
+        dataset["train"],
+        cehrgpt_args.hyperparameter_tuning_percentage,
+        training_args.seed,
+    )
+    sampled_val = sample_dataset(
+        dataset["validation"],
+        cehrgpt_args.hyperparameter_tuning_percentage,
+        training_args.seed,
+    )
+    # Create trainer
+    hyperparam_trainer = trainer_class(
+        model_init=model_init,
+        data_collator=data_collator,
+        train_dataset=sampled_train,
+        eval_dataset=sampled_val,
+        callbacks=[
+            EarlyStoppingCallback(model_args.early_stopping_patience),
+            OptunaMetricCallback(),
+        ],
+        args=training_args,
+    )
+
+    best_trial = hyperparam_trainer.hyperparameter_search(
+        direction="minimize",
+        hp_space=partial(
+            hp_space,
+            cehrgpt_args=cehrgpt_args,
+        ),
+        backend="optuna",
+        n_trials=cehrgpt_args.n_trials,
+        compute_objective=lambda m: m["optuna_best_metric"],
+        sampler=sampler,
+    )
+
+    # Log results
+    LOG.info("=" * 50)
+    LOG.info("HYPERPARAMETER SEARCH COMPLETED")
+    LOG.info("=" * 50)
+    LOG.info(f"Best hyperparameters: {best_trial.hyperparameters}")
+    LOG.info(f"Best metric (eval_loss): {best_trial.objective}")
+    LOG.info(f"Best run_id: {best_trial.run_id}")
+    LOG.info("=" * 50)
+
+    # Restore original settings and update with best hyperparameters
+    training_args.save_total_limit = save_total_limit_original
+    for k, v in best_trial.hyperparameters.items():
+        setattr(training_args, k, v)
+        LOG.info(f"Updated training_args.{k} = {v}")
 
-    return training_args
+    return training_args, best_trial.run_id
```
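For a sense of the grid-mode trial budget: `calculate_total_combinations` is the product of the candidate-list lengths, and `n_trials` is raised to that product when it is set too low. A quick check with hypothetical candidate lists:

```python
search_space = {
    "learning_rate": [1e-5, 5e-5, 1e-4],     # 3 candidates (hypothetical)
    "weight_decay": [0.0, 0.01],             # 2
    "per_device_train_batch_size": [8, 16],  # 2
    "num_train_epochs": [5, 10],             # 2
}

total = 1
for values in search_space.values():
    total *= len(values)

print(total)  # 3 * 2 * 2 * 2 = 24 trials needed to cover the grid
```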
`cehrgpt/runners/sample_packing_trainer.py`

```diff
@@ -1,9 +1,10 @@
 from typing import Optional, Union
 
+import torch
 from datasets import Dataset
 from torch.utils.data import DataLoader
 from transformers import Trainer
-from transformers.trainer_utils import has_length
+from transformers.trainer_utils import has_length, seed_worker
 from transformers.utils import import_utils, logging
 
 from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
```
```diff
@@ -62,7 +63,10 @@ class SamplePackingTrainer(Trainer):
             if "num_of_concepts" in train_dataset.column_names:
                 lengths = train_dataset["num_of_concepts"]
             else:
-                lengths = [
+                lengths = [
+                    len(sample["input_ids"])
+                    for sample in train_dataset.select_columns("input_ids")
+                ]
 
             LOG.info("Finished computing lengths for the train dataset")
         else:
```
```diff
@@ -102,6 +106,11 @@ class SamplePackingTrainer(Trainer):
             "persistent_workers": self.args.dataloader_persistent_workers,
             "batch_sampler": batch_sampler,
         }
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def get_eval_dataloader(
```
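The new guard only attaches options that apply to map-style datasets (`drop_last`, `worker_init_fn`, `prefetch_factor`). To illustrate the batching idea behind `SamplePackingBatchSampler`, here is a simplified greedy packer under a token budget; this is a sketch, not the real class:

```python
from typing import Iterator, List

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ToySequenceDataset(Dataset):
    def __init__(self, lengths: List[int]):
        self.lengths = lengths

    def __len__(self) -> int:
        return len(self.lengths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.zeros(self.lengths[idx], dtype=torch.long)


class GreedyPackingBatchSampler(Sampler[List[int]]):
    """Greedily packs sample indices into batches under a max-token budget."""

    def __init__(self, lengths: List[int], max_tokens_per_batch: int):
        self.lengths = lengths
        self.max_tokens = max_tokens_per_batch

    def __iter__(self) -> Iterator[List[int]]:
        batch, used = [], 0
        for idx, n in enumerate(self.lengths):
            if batch and used + n > self.max_tokens:
                yield batch
                batch, used = [], 0
            batch.append(idx)
            used += n
        if batch:
            yield batch

    def __len__(self) -> int:
        return len(self.lengths)  # upper bound; true count depends on packing


lengths = [100, 300, 250, 50, 400]
loader = DataLoader(
    ToySequenceDataset(lengths),
    batch_sampler=GreedyPackingBatchSampler(lengths, max_tokens_per_batch=512),
    collate_fn=lambda seqs: torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True),
)
for batch in loader:
    print(batch.shape)  # each batch stays under the 512-token budget
```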
`cehrgpt/tools/linear_prob/compute_cehrgpt_features.py`

```diff
@@ -9,11 +9,15 @@ from typing import Optional, Union
 
 import numpy as np
 import pandas as pd
+import polars as pl
 import torch
 import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.runner_util import generate_prepared_ds_path
 from datasets import concatenate_datasets, load_from_disk
+from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
+    batched_powerSGD_hook,
+)
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers.trainer_utils import is_main_process
```
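The new `batched_powerSGD_hook` import brings in PowerSGD gradient compression for DDP. This diff only shows the import; for reference, registering the hook typically looks like the sketch below (illustrative, not taken from this file):

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch only: assumes torch.distributed is already initialized
# (e.g. dist.init_process_group(...) under torchrun) and the model
# has been moved to the right device.
model = DDP(torch.nn.Linear(16, 1))
state = powerSGD.PowerSGDState(
    process_group=None,           # None -> the default process group
    matrix_approximation_rank=1,  # rank of the low-rank gradient approximation
    start_powerSGD_iter=10,       # run vanilla allreduce for the first steps
)
model.register_comm_hook(state, powerSGD.batched_powerSGD_hook)
```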
```diff
@@ -241,7 +245,6 @@ def main():
         SamplePackingCehrGptDataCollator,
         cehrgpt_args.max_tokens_per_batch,
         cehrgpt_model.config.max_position_embeddings,
-        add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
     )
     train_batch_sampler = SamplePackingBatchSampler(
         lengths=train_set["num_of_concepts"],
```
```diff
@@ -332,7 +335,7 @@ def main():
 
     data_loaders = [("train", train_loader), ("test", test_dataloader)]
 
-    ve_token_id = cehrgpt_tokenizer.
+    ve_token_id = cehrgpt_tokenizer.ve_token_id
     for split, data_loader in data_loaders:
         # Ensure prediction folder exists
         feature_output_folder = (
```
```diff
@@ -383,6 +386,9 @@ def main():
             if "epoch_times" in batch:
                 batch.pop("epoch_times")
 
+            if "ages" in batch:
+                batch.pop("ages")
+
             batch = {k: v.to(device) for k, v in batch.items()}
             # Forward pass
             cehrgpt_output = cehrgpt_model(
```
`cehrgpt-0.1.3.dist-info/METADATA` (new file, `@@ -0,0 +1,238 @@`)

```
Metadata-Version: 2.4
Name: cehrgpt
Version: 0.1.3
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
License: MIT License
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.10.0
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: cehrbert>=1.4.8
Requires-Dist: cehrbert_data==0.0.11
Requires-Dist: openai==1.54.3
Requires-Dist: optuna==4.0.0
Requires-Dist: transformers==4.44.1
Requires-Dist: tokenizers==0.19.0
Requires-Dist: peft==0.10.0
Requires-Dist: lightgbm
Requires-Dist: polars
Provides-Extra: dev
Requires-Dist: pre-commit; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: pytest-subtests; extra == "dev"
Requires-Dist: rootutils; extra == "dev"
Requires-Dist: hypothesis; extra == "dev"
Requires-Dist: black; extra == "dev"
Provides-Extra: flash-attn
Requires-Dist: flash_attn; extra == "flash-attn"
Dynamic: license-file
```

# CEHRGPT

[![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
[![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)

CEHRGPT is a multi-task foundation model for structured electronic health records (EHR) data that supports three capabilities: feature representation, zero-shot prediction, and synthetic data generation.

## 🎯 Key Capabilities

### Feature Representation
Extract meaningful patient embeddings from sequences of medical events using **linear probing** techniques for downstream tasks such as disease prediction, patient clustering, and risk stratification.

### Zero-Shot Prediction
Generate outcome predictions directly from prompts without requiring task-specific training, enabling rapid evaluation in low-label clinical settings.

### Synthetic Data Generation
Generate comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques to ensure generated data contains no identifiable information.

The platform is fully compatible with the OMOP Common Data Model for seamless integration with existing healthcare systems.

## 🚀 Installation

Clone the repository and install dependencies:

```bash
git clone https://github.com/knatarajan-lab/cehrgpt.git
cd cehrgpt
pip install .
```

## 📋 Prerequisites

Before getting started, set up the required environment variables:

```bash
export CEHRGPT_HOME=$(git rev-parse --show-toplevel)
export OMOP_DIR=""           # Path to your OMOP data
export CEHR_GPT_DATA_DIR=""  # Path for processed data storage
export CEHR_GPT_MODEL_DIR="" # Path for model storage
```

Create the dataset cache directory:

```bash
mkdir $CEHR_GPT_DATA_DIR/dataset_prepared
```

## 🏗️ Model Training

### Step 1: Generate Pre-training Data from OMOP

Generate the training data following the [Data Generation Instruction](./data_generation.md).

### Step 2: Pre-train CEHR-GPT

Train the foundation model:

```bash
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $CEHR_GPT_MODEL_DIR \
  --data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
  --dataset_prepared_path "$CEHR_GPT_DATA_DIR/dataset_prepared" \
  --do_train true --seed 42 \
  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
  --evaluation_strategy epoch --save_strategy epoch \
  --sample_packing --max_tokens_per_batch 16384 \
  --warmup_ratio 0.01 --weight_decay 0.01 \
  --num_train_epochs 50 --learning_rate 0.0002 \
  --use_early_stopping --early_stopping_threshold 0.001
```

> **Tip**: Increase `max_position_embeddings` for longer context windows based on your use case.

## 🎯 Feature Representation

CEHR-GPT enables extraction of meaningful patient embeddings from medical event sequences using **linear probing** techniques for downstream prediction tasks. The feature representation pipeline includes label generation, patient sequence extraction, and linear regression model training on the extracted representations.

For detailed instructions including cohort creation, patient feature extraction, and linear probing evaluation, please follow the [Feature Representation Guide](./feature_representation.md).

## 🔮 Zero-Shot Prediction

CEHR-GPT can generate outcome predictions directly from clinical prompts without requiring task-specific training, making it ideal for rapid evaluation in low-label clinical settings. The zero-shot prediction capability performs time-to-event analysis by processing patient sequences and generating risk predictions based on learned medical patterns.

For complete setup instructions including label generation, sequence preparation, and prediction execution, please follow the [Zero-Shot Prediction Guide](./zero_shot_prediction.md).

## 🧬 Synthetic Data Generation

CEHR-GPT generates comprehensive synthetic patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques. The synthetic data maintains statistical fidelity to real patient populations without containing identifiable information, and outputs are fully compatible with the OMOP Common Data Model.

For step-by-step instructions on generating synthetic sequences and converting them to OMOP format, please follow the [Synthetic Data Generation Guide](./synthetic_data_generation.md).

## 📊 MEDS Support

CEHR-GPT supports the Medical Event Data Standard (MEDS) format for enhanced interoperability.

### Prerequisites

Configure MEDS-specific environment variables:

```bash
export CEHR_GPT_MODEL_DIR="" # CEHR-GPT model directory
export MEDS_DIR=""           # MEDS data directory
export MEDS_READER_DIR=""    # MEDS reader output directory
```

### Step 1: Create MIMIC MEDS Data

Transform MIMIC files to MEDS format following the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository instructions.

### Step 2: Prepare MEDS Reader

Convert MEDS data for CEHR-GPT compatibility:

```bash
meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
```

### Step 3: Pre-train with MEDS Data

Execute pre-training using MEDS format:

```bash
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $CEHR_GPT_MODEL_DIR \
  --data_folder $MEDS_READER_DIR \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --do_train true --seed 42 \
  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
  --evaluation_strategy epoch --save_strategy epoch \
  --sample_packing --max_tokens_per_batch 16384 \
  --warmup_steps 500 --weight_decay 0.01 \
  --num_train_epochs 50 --learning_rate 0.0002 \
  --use_early_stopping --early_stopping_threshold 0.001 \
  --is_data_in_meds --inpatient_att_function_type day \
  --att_function_type day --include_inpatient_hour_token \
  --include_auxiliary_token --include_demographic_prompt \
  --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
```

### Step 4: Generate MEDS Trajectories

#### Environment Setup

Configure trajectory generation environment:

```bash
export MEDS_LABEL_COHORT_DIR="" # Cohort labels directory (parquet files)
export MEDS_TRAJECTORY_DIR=""   # Trajectory output directory
```

#### Generate Synthetic Trajectories

Create patient trajectories with the trained model:

```bash
python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
  --cohort_folder $MEDS_LABEL_COHORT_DIR \
  --data_folder $MEDS_READER_DIR \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $MEDS_TRAJECTORY_DIR \
  --per_device_eval_batch_size 16 \
  --num_of_trajectories_per_sample 2 \
  --generation_input_length 4096 \
  --generation_max_new_tokens 4096 \
  --is_data_in_meds \
  --att_function_type day --inpatient_att_function_type day \
  --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
  --include_auxiliary_token --include_demographic_prompt \
  --include_inpatient_hour_token
```

> **Important**: Ensure `generation_input_length` + `generation_max_new_tokens` ≤ `max_position_embeddings` (8192).

#### Parameter Reference

- `generation_input_length`: Input context length for generation
- `generation_max_new_tokens`: Maximum new tokens to generate
- `num_of_trajectories_per_sample`: Number of trajectories per patient sample

## 📖 Citation

If you use CEHRGPT in your research, please cite:

```bibtex
@article{cehrgpt2024,
  title={CEHRGPT: Synthetic Data Generation for Electronic Health Records},
  author={Natarajan, K and others},
  journal={arXiv preprint arXiv:2402.04400},
  year={2024}
}
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
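As a quick sanity check of the context-window constraint above (values taken from the example commands; illustrative only):

```python
max_position_embeddings = 8192
generation_input_length = 4096
generation_max_new_tokens = 4096

# Prompt plus generated tokens must fit in the model's context window.
assert (
    generation_input_length + generation_max_new_tokens <= max_position_embeddings
), "Reduce generation_input_length or generation_max_new_tokens"
```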