cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.4.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hyperparameter_search_util.py
@@ -1,5 +1,5 @@
  from functools import partial
- from typing import Callable, Tuple
+ from typing import Callable, List, Optional, Tuple, Union

  import optuna
  from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
@@ -64,28 +64,99 @@ class OptunaMetricCallback(TrainerCallback):
          metrics.update({"optuna_best_metric": metrics["eval_loss"]})


- # Define the hyperparameter search space with parameters
- def hp_space(
-     trial: optuna.Trial,
-     lr_range: Tuple[float, float] = (1e-5, 5e-5),
-     batch_sizes=None,
-     weight_decays: Tuple[float, float] = (1e-4, 1e-2),
-     num_train_epochs: Tuple[float, ...] = 10,
- ):
-     if batch_sizes is None:
-         batch_sizes = [4, 8]
+ def get_suggestion(
+     trial,
+     hyperparameter_name: str,
+     hyperparameters: List[Union[float, int]],
+     is_grid: bool = False,
+ ) -> Union[float, int]:
+     """
+     Get hyperparameter suggestion based on search mode.
+
+     Args:
+         trial: Optuna trial object
+         hyperparameter_name: Name of the hyperparameter
+         hyperparameters: List of hyperparameter values
+         is_grid: Whether to use grid search mode
+
+     Returns:
+         Suggested hyperparameter value
+
+     Raises:
+         RuntimeError: If Bayesian mode is used with incorrect number of bounds
+     """
+     if is_grid:
+         return trial.suggest_categorical(hyperparameter_name, hyperparameters)
+
+     # For Bayesian optimization, we need exactly 2 values (lower and upper bounds)
+     if len(hyperparameters) != 2:
+         raise RuntimeError(
+             f"{hyperparameter_name} must contain exactly two values (lower and upper bound) "
+             f"for Bayesian Optimization, but {len(hyperparameters)} values were provided: {hyperparameters}"
+         )
+
+     # Ensure bounds are sorted
+     lower, upper = sorted(hyperparameters)
+     return trial.suggest_float(hyperparameter_name, lower, upper, log=True)
+
+
+ def hp_space(trial: optuna.Trial, cehrgpt_args: CehrGPTArguments):
+     """
+     Define the hyperparameter search space.
+
+     Args:
+         trial: Optuna trial object
+         cehrgpt_args: CehrGPTArguments
+     Returns:
+         Dictionary of hyperparameter suggestions
+     """
+
+     is_grid = cehrgpt_args.hyperparameter_tuning_is_grid
+     learning_rates = cehrgpt_args.hyperparameter_learning_rates
+     weight_decays = cehrgpt_args.hyperparameter_weight_decays
+     batch_sizes = cehrgpt_args.hyperparameter_batch_sizes
+     num_train_epochs = cehrgpt_args.hyperparameter_num_train_epochs
+
      return {
-         "learning_rate": trial.suggest_float("learning_rate", *lr_range, log=True),
+         "learning_rate": get_suggestion(
+             trial, "learning_rate", learning_rates, is_grid
+         ),
          "per_device_train_batch_size": trial.suggest_categorical(
              "per_device_train_batch_size", batch_sizes
          ),
-         "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
+         "weight_decay": get_suggestion(trial, "weight_decay", weight_decays, is_grid),
          "num_train_epochs": trial.suggest_categorical(
              "num_train_epochs", num_train_epochs
          ),
      }


+ def create_grid_search_space(cehrgpt_args: CehrGPTArguments):
+     """
+     Create the search space dictionary for GridSampler.
+
+     Args:
+         cehrgpt_args: CehrGPTArguments
+
+     Returns:
+         Dictionary defining the grid search space
+     """
+     return {
+         "learning_rate": cehrgpt_args.hyperparameter_learning_rates,
+         "weight_decay": cehrgpt_args.hyperparameter_weight_decays,
+         "per_device_train_batch_size": cehrgpt_args.hyperparameter_batch_sizes,
+         "num_train_epochs": cehrgpt_args.hyperparameter_num_train_epochs,
+     }
+
+
+ def calculate_total_combinations(search_space: dict) -> int:
+     """Calculate total number of combinations in grid search."""
+     total = 1
+     for values in search_space.values():
+         total *= len(values)
+     return total
+
+
  def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
      """
      Samples a subset of the given dataset based on a specified percentage.
@@ -113,7 +184,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
      returns the "test" portion, which is the specified percentage of the dataset.
      - Ensure that `percentage` is between 0 and 1 to avoid errors.
      """
-     if percentage == 1.0:
+     if percentage >= 1.0:
          return data

      return data.train_test_split(
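For reference, `sample_dataset` keeps the "test" side of a Hugging Face `train_test_split`; a standalone toy sketch of that behavior with the `datasets` library (illustrative data, not project code):

```python
from datasets import Dataset

data = Dataset.from_dict({"x": list(range(100))})
# percentage < 1.0: keep the "test" split, i.e. 20% of the rows here
subset = data.train_test_split(test_size=0.2, seed=42)["test"]
print(len(subset))  # 20
# percentage >= 1.0 now short-circuits and returns the dataset unchanged
```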
@@ -130,14 +201,13 @@ def perform_hyperparameter_search(
      training_args: TrainingArguments,
      model_args: ModelArguments,
      cehrgpt_args: CehrGPTArguments,
- ) -> TrainingArguments:
+ ) -> Tuple[TrainingArguments, Optional[str]]:
      """
      Perform hyperparameter tuning for the CehrGPT model using Optuna with the Hugging Face Trainer.

-     This function initializes a Trainer with sampled training and validation sets, and performs
-     a hyperparameter search using Optuna. The search tunes learning rate, batch size, and weight decay
-     to optimize model performance based on a specified objective metric (e.g., validation loss).
-     After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.
+     This function supports two modes:
+     1. Bayesian Optimization (TPE): Intelligently explores hyperparameter space using bounds
+     2. Grid Search: Exhaustively tests all combinations of discrete values

      Args:
          trainer_class: A Trainer or its subclass
@@ -147,15 +217,15 @@ def perform_hyperparameter_search(
          training_args (TrainingArguments): Configuration for training parameters (e.g., epochs, evaluation strategy).
          model_args (ModelArguments): Model configuration arguments, including early stopping parameters.
          cehrgpt_args (CehrGPTArguments): Additional arguments specific to CehrGPT, including hyperparameter
-             tuning options such as learning rate range, batch sizes, and tuning percentage.
+             tuning options and search mode configuration.

      Returns:
-         TrainingArguments: Updated `TrainingArguments` instance containing the best hyperparameters found
-         from the search.
+         Tuple[TrainingArguments, Optional[str]]: Updated TrainingArguments with best hyperparameters
+             and optional run_id of the best trial.

      Example:
          ```
-         best_training_args = perform_hyperparameter_search(
+         best_training_args, run_id = perform_hyperparameter_search(
              trainer_class=Trainer,
              model_init=my_model_init,
              dataset=my_dataset_dict,
@@ -176,50 +246,91 @@ def perform_hyperparameter_search(
      Logging:
          Logs the best hyperparameters found at the end of the search.
      """
-     if cehrgpt_args.hyperparameter_tuning:
-         sampled_train = sample_dataset(
-             dataset["train"],
-             cehrgpt_args.hyperparameter_tuning_percentage,
-             training_args.seed,
-         )
-         sampled_val = sample_dataset(
-             dataset["validation"],
-             cehrgpt_args.hyperparameter_tuning_percentage,
-             training_args.seed,
-         )
-         hyperparam_trainer = trainer_class(
-             model_init=model_init,
-             data_collator=data_collator,
-             train_dataset=sampled_train,
-             eval_dataset=sampled_val,
-             callbacks=[
-                 EarlyStoppingCallback(model_args.early_stopping_patience),
-                 OptunaMetricCallback(),
-             ],
-             args=training_args,
-         )
-         # Perform hyperparameter search
-         best_trial = hyperparam_trainer.hyperparameter_search(
-             direction="minimize",
-             hp_space=partial(
-                 hp_space,
-                 lr_range=(cehrgpt_args.lr_low, cehrgpt_args.lr_high),
-                 weight_decays=(
-                     cehrgpt_args.weight_decays_low,
-                     cehrgpt_args.weight_decays_high,
-                 ),
-                 batch_sizes=cehrgpt_args.hyperparameter_batch_sizes,
-                 num_train_epochs=cehrgpt_args.hyperparameter_num_train_epochs,
-             ),
-             backend="optuna",
-             n_trials=cehrgpt_args.n_trials,
-             compute_objective=lambda m: m["optuna_best_metric"],
-             # Ensure reproducibility
-             sampler=optuna.samplers.TPESampler(seed=training_args.seed),
-         )
-         LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
-         # Update training arguments with best hyperparameters and set epochs based on adjusted effective epochs
-         for k, v in best_trial.hyperparameters.items():
-             setattr(training_args, k, v)
+     if not cehrgpt_args.hyperparameter_tuning:
+         return training_args, None
+
+     # Prepare hyperparameters based on mode
+     if cehrgpt_args.hyperparameter_tuning_is_grid:
+         search_space = create_grid_search_space(cehrgpt_args)
+         total_combinations = calculate_total_combinations(search_space)
+
+         LOG.info(f"Grid search mode: Testing {total_combinations} combinations")
+         LOG.info(f"Search space: {search_space}")
+
+         # Adjust n_trials for grid search if not set appropriately
+         if cehrgpt_args.n_trials < total_combinations:
+             LOG.warning(
+                 f"n_trials ({cehrgpt_args.n_trials}) is less than total combinations ({total_combinations}). "
+                 f"Setting n_trials to {total_combinations} to test all combinations."
+             )
+             cehrgpt_args.n_trials = total_combinations
+
+         # Configure sampler based on search mode
+         sampler = optuna.samplers.GridSampler(search_space, seed=training_args.seed)
+     else:
+         LOG.info("Bayesian optimization mode (TPE)")
+         LOG.info(f"Learning rate bounds: {cehrgpt_args.hyperparameter_learning_rates}")
+         LOG.info(f"Weight decay bounds: {cehrgpt_args.hyperparameter_weight_decays}")
+         LOG.info(f"Batch sizes: {cehrgpt_args.hyperparameter_batch_sizes}")
+         LOG.info(f"Epochs: {cehrgpt_args.hyperparameter_num_train_epochs}")
+         # Configure the TPE sampler
+         sampler = optuna.samplers.TPESampler(seed=training_args.seed)
+
+     # Prepare datasets
+     save_total_limit_original = training_args.save_total_limit
+     training_args.save_total_limit = 1
+
+     sampled_train = sample_dataset(
+         dataset["train"],
+         cehrgpt_args.hyperparameter_tuning_percentage,
+         training_args.seed,
+     )
+     sampled_val = sample_dataset(
+         dataset["validation"],
+         cehrgpt_args.hyperparameter_tuning_percentage,
+         training_args.seed,
+     )
+     # Create trainer
+     hyperparam_trainer = trainer_class(
+         model_init=model_init,
+         data_collator=data_collator,
+         train_dataset=sampled_train,
+         eval_dataset=sampled_val,
+         callbacks=[
+             EarlyStoppingCallback(model_args.early_stopping_patience),
+             OptunaMetricCallback(),
+         ],
+         args=training_args,
+     )
+
+     best_trial = hyperparam_trainer.hyperparameter_search(
+         direction="minimize",
+         hp_space=partial(
+             hp_space,
+             cehrgpt_args=cehrgpt_args,
+         ),
+         backend="optuna",
+         n_trials=cehrgpt_args.n_trials,
+         compute_objective=lambda m: m["optuna_best_metric"],
+         sampler=sampler,
+     )
+
+     # Log results
+     LOG.info("=" * 50)
+     LOG.info("HYPERPARAMETER SEARCH COMPLETED")
+     LOG.info("=" * 50)
+     LOG.info(f"Best hyperparameters: {best_trial.hyperparameters}")
+     LOG.info(f"Best metric (eval_loss): {best_trial.objective}")
+     LOG.info(f"Best run_id: {best_trial.run_id}")
+     LOG.info("=" * 50)
+
+     # Restore original settings and update with best hyperparameters
+     training_args.save_total_limit = save_total_limit_original
+     for k, v in best_trial.hyperparameters.items():
+         setattr(training_args, k, v)
+         LOG.info(f"Updated training_args.{k} = {v}")

-     return training_args
+     return training_args, best_trial.run_id
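The core of this change is the switch between two Optuna samplers. Below is a minimal, standalone sketch of the two modes, assuming `optuna==4.0.0` as pinned in this release's METADATA; the objective function and values are illustrative stand-ins for the Trainer's `optuna_best_metric`, not project code:

```python
import math

import optuna


def objective(trial: optuna.Trial) -> float:
    # Stand-in for the eval-loss objective that the runner minimizes.
    lr = trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)
    batch_size = trial.suggest_categorical("per_device_train_batch_size", [4, 8])
    return abs(math.log10(lr) + 4.5) + 0.01 * batch_size


# Grid mode: GridSampler exhausts every combination of the discrete values,
# so n_trials must cover the full grid (hence the runner bumps n_trials).
grid_space = {
    "learning_rate": [1e-5, 5e-5],
    "per_device_train_batch_size": [4, 8],
}
grid_study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.GridSampler(grid_space, seed=42),
)
grid_study.optimize(objective, n_trials=4)  # 2 x 2 combinations

# Bayesian mode: TPESampler treats the two learning-rate values as
# (lower, upper) bounds, which is why get_suggestion() insists on exactly two.
tpe_study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
)
tpe_study.optimize(objective, n_trials=10)
```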
cehrgpt/runners/sample_packing_trainer.py
@@ -1,9 +1,10 @@
  from typing import Optional, Union

+ import torch
  from datasets import Dataset
  from torch.utils.data import DataLoader
  from transformers import Trainer
- from transformers.trainer_utils import has_length
+ from transformers.trainer_utils import has_length, seed_worker
  from transformers.utils import import_utils, logging

  from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
@@ -62,7 +63,10 @@ class SamplePackingTrainer(Trainer):
          if "num_of_concepts" in train_dataset.column_names:
              lengths = train_dataset["num_of_concepts"]
          else:
-             lengths = [len(sample["input_ids"]) for sample in train_dataset]
+             lengths = [
+                 len(sample["input_ids"])
+                 for sample in train_dataset.select_columns("input_ids")
+             ]

          LOG.info("Finished computing lengths for the train dataset")
      else:
@@ -102,6 +106,11 @@ class SamplePackingTrainer(Trainer):
              "persistent_workers": self.args.dataloader_persistent_workers,
              "batch_sampler": batch_sampler,
          }
+         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+             dataloader_params["drop_last"] = self.args.dataloader_drop_last
+             dataloader_params["worker_init_fn"] = seed_worker
+             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
          return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

      def get_eval_dataloader(
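For intuition on why the trainer precomputes per-sample lengths and hands them to a batch sampler, here is a toy, self-contained sketch of token-budget packing; the `pack_by_budget` helper below is hypothetical and stands in for the project's `SamplePackingBatchSampler`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pack_by_budget(lengths, max_tokens):
    """Greedily group sample indices so each batch stays within a token budget."""
    batch, used = [], 0
    for idx, n in enumerate(lengths):
        if batch and used + n > max_tokens:
            yield batch
            batch, used = [], 0
        batch.append(idx)
        used += n
    if batch:
        yield batch


sequences = [torch.ones(n, dtype=torch.long) for n in (5, 3, 8, 2, 7)]
lengths = [len(s) for s in sequences]
loader = DataLoader(
    sequences,  # a plain list acts as a map-style dataset
    batch_sampler=list(pack_by_budget(lengths, max_tokens=10)),
    collate_fn=lambda xs: pad_sequence(xs, batch_first=True),
)
for batch in loader:
    print(batch.shape)  # (2, 5), then (2, 8), then (1, 7)
```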
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py
@@ -9,11 +9,15 @@ from typing import Optional, Union

  import numpy as np
  import pandas as pd
+ import polars as pl
  import torch
  import torch.distributed as dist
  from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
  from cehrbert.runners.runner_util import generate_prepared_ds_path
  from datasets import concatenate_datasets, load_from_disk
+ from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
+     batched_powerSGD_hook,
+ )
  from torch.utils.data import DataLoader
  from tqdm import tqdm
  from transformers.trainer_utils import is_main_process
@@ -241,7 +245,6 @@ def main():
          SamplePackingCehrGptDataCollator,
          cehrgpt_args.max_tokens_per_batch,
          cehrgpt_model.config.max_position_embeddings,
-         add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
      )
      train_batch_sampler = SamplePackingBatchSampler(
          lengths=train_set["num_of_concepts"],
@@ -332,7 +335,7 @@ def main():

      data_loaders = [("train", train_loader), ("test", test_dataloader)]

-     ve_token_id = cehrgpt_tokenizer._convert_token_to_id("[VE]")
+     ve_token_id = cehrgpt_tokenizer.ve_token_id
      for split, data_loader in data_loaders:
          # Ensure prediction folder exists
          feature_output_folder = (
@@ -383,6 +386,9 @@ def main():
          if "epoch_times" in batch:
              batch.pop("epoch_times")

+         if "ages" in batch:
+             batch.pop("ages")
+
          batch = {k: v.to(device) for k, v in batch.items()}
          # Forward pass
          cehrgpt_output = cehrgpt_model(
cehrgpt-0.1.4.dist-info/METADATA (new file)
@@ -0,0 +1,238 @@
+ Metadata-Version: 2.4
+ Name: cehrgpt
+ Version: 0.1.4
+ Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
+ Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
+ License: MIT License
+ Classifier: Development Status :: 5 - Production/Stable
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.10.0
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: cehrbert>=1.4.8
+ Requires-Dist: cehrbert_data>=0.1.1
+ Requires-Dist: openai==1.54.3
+ Requires-Dist: optuna==4.0.0
+ Requires-Dist: transformers==4.44.1
+ Requires-Dist: tokenizers==0.19.0
+ Requires-Dist: peft==0.10.0
+ Requires-Dist: lightgbm
+ Requires-Dist: polars
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: pytest-cov; extra == "dev"
+ Requires-Dist: pytest-subtests; extra == "dev"
+ Requires-Dist: rootutils; extra == "dev"
+ Requires-Dist: hypothesis; extra == "dev"
+ Requires-Dist: black; extra == "dev"
+ Provides-Extra: flash-attn
+ Requires-Dist: flash_attn; extra == "flash-attn"
+ Dynamic: license-file
+
+ # CEHRGPT
+
+ [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
+ ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
+
+ CEHRGPT is a multi-task foundation model for structured electronic health record (EHR) data that supports three capabilities: feature representation, zero-shot prediction, and synthetic data generation.
+
+ ## 🎯 Key Capabilities
+
+ ### Feature Representation
+ Extract meaningful patient embeddings from sequences of medical events using **linear probing** techniques for downstream tasks such as disease prediction, patient clustering, and risk stratification.
+
+ ### Zero-Shot Prediction
+ Generate outcome predictions directly from prompts without requiring task-specific training, enabling rapid evaluation in low-label clinical settings.
+
+ ### Synthetic Data Generation
+ Generate comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques to ensure the generated data contains no identifiable information.
+ The platform is fully compatible with the OMOP Common Data Model for seamless integration with existing healthcare systems.
+
+ ## 🚀 Installation
+
+ Clone the repository and install dependencies:
+
+ ```bash
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
+ cd cehrgpt
+ pip install .
+ ```
+
+ ## 📋 Prerequisites
+
+ Before getting started, set up the required environment variables:
+
+ ```bash
+ export CEHRGPT_HOME=$(git rev-parse --show-toplevel)
+ export OMOP_DIR=""            # Path to your OMOP data
+ export CEHR_GPT_DATA_DIR=""   # Path for processed data storage
+ export CEHR_GPT_MODEL_DIR=""  # Path for model storage
+ ```
+
+ Create the dataset cache directory:
+ ```bash
+ mkdir $CEHR_GPT_DATA_DIR/dataset_prepared
+ ```
+
+ ## 🏗️ Model Training
+
+ ### Step 1: Generate Pre-training Data from OMOP
+
+ Generate the training data following the [Data Generation Instructions](./data_generation.md).
+
+ ### Step 2: Pre-train CEHR-GPT
+
+ Train the foundation model:
+
+ ```bash
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
+   --model_name_or_path $CEHR_GPT_MODEL_DIR \
+   --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
+   --output_dir $CEHR_GPT_MODEL_DIR \
+   --data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
+   --dataset_prepared_path "$CEHR_GPT_DATA_DIR/dataset_prepared" \
+   --do_train true --seed 42 \
+   --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
+   --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
+   --evaluation_strategy epoch --save_strategy epoch \
+   --sample_packing --max_tokens_per_batch 16384 \
+   --warmup_ratio 0.01 --weight_decay 0.01 \
+   --num_train_epochs 50 --learning_rate 0.0002 \
+   --use_early_stopping --early_stopping_threshold 0.001
+ ```
+
+ > **Tip**: Increase `max_position_embeddings` for longer context windows based on your use case.
+
+ ## 🎯 Feature Representation
+
+ CEHR-GPT enables extraction of meaningful patient embeddings from medical event sequences using **linear probing** techniques for downstream prediction tasks. The feature representation pipeline includes label generation, patient sequence extraction, and linear regression model training on the extracted representations.
+
+ For detailed instructions including cohort creation, patient feature extraction, and linear probing evaluation, please follow the [Feature Representation Guide](./feature_representation.md).
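As a concrete picture of the probing step, here is a minimal sketch that fits a logistic-regression probe on frozen patient embeddings, assuming scikit-learn is installed; the arrays are synthetic stand-ins for the features that `cehrgpt.tools.linear_prob.compute_cehrgpt_features` would produce:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic stand-ins: X would be frozen CEHR-GPT patient embeddings,
# y a binary outcome label per patient.
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 768))  # 768 matches the --hidden_size used above
y = rng.integers(0, 2, size=1000)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print("AUROC:", roc_auc_score(y_te, probe.predict_proba(X_te)[:, 1]))
```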
+
+ ## 🔮 Zero-Shot Prediction
+
+ CEHR-GPT can generate outcome predictions directly from clinical prompts without requiring task-specific training, making it ideal for rapid evaluation in low-label clinical settings. The zero-shot prediction capability performs time-to-event analysis by processing patient sequences and generating risk predictions based on learned medical patterns.
+
+ For complete setup instructions including label generation, sequence preparation, and prediction execution, please follow the [Zero-Shot Prediction Guide](./zero_shot_prediction.md).
+
+ ## 🧬 Synthetic Data Generation
+
+ CEHR-GPT generates comprehensive synthetic patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques. The synthetic data maintains statistical fidelity to real patient populations without containing identifiable information, and the outputs are fully compatible with the OMOP Common Data Model.
+
+ For step-by-step instructions on generating synthetic sequences and converting them to OMOP format, please follow the [Synthetic Data Generation Guide](./synthetic_data_generation.md).
+
+ ## 📊 MEDS Support
+
+ CEHR-GPT supports the Medical Event Data Standard (MEDS) format for enhanced interoperability.
+
+ ### Prerequisites
+
+ Configure MEDS-specific environment variables:
+
+ ```bash
+ export CEHR_GPT_MODEL_DIR=""  # CEHR-GPT model directory
+ export MEDS_DIR=""            # MEDS data directory
+ export MEDS_READER_DIR=""     # MEDS reader output directory
+ ```
+
+ ### Step 1: Create MIMIC MEDS Data
+
+ Transform MIMIC files to MEDS format following the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository instructions.
+
+ ### Step 2: Prepare MEDS Reader
+
+ Convert MEDS data for CEHR-GPT compatibility:
+
+ ```bash
+ meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
+ ```
+
+ ### Step 3: Pre-train with MEDS Data
+
+ Execute pre-training using MEDS format:
+
+ ```bash
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
+   --model_name_or_path $CEHR_GPT_MODEL_DIR \
+   --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
+   --output_dir $CEHR_GPT_MODEL_DIR \
+   --data_folder $MEDS_READER_DIR \
+   --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
+   --do_train true --seed 42 \
+   --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
+   --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
+   --evaluation_strategy epoch --save_strategy epoch \
+   --sample_packing --max_tokens_per_batch 16384 \
+   --warmup_steps 500 --weight_decay 0.01 \
+   --num_train_epochs 50 --learning_rate 0.0002 \
+   --use_early_stopping --early_stopping_threshold 0.001 \
+   --is_data_in_meds --inpatient_att_function_type day \
+   --att_function_type day --include_inpatient_hour_token \
+   --include_auxiliary_token --include_demographic_prompt \
+   --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
+ ```
+
+ ### Step 4: Generate MEDS Trajectories
+
+ #### Environment Setup
+
+ Configure the trajectory generation environment:
+
+ ```bash
+ export MEDS_LABEL_COHORT_DIR=""  # Cohort labels directory (parquet files)
+ export MEDS_TRAJECTORY_DIR=""    # Trajectory output directory
+ ```
+
+ #### Generate Synthetic Trajectories
+
+ Create patient trajectories with the trained model:
+
+ ```bash
+ python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
+   --cohort_folder $MEDS_LABEL_COHORT_DIR \
+   --data_folder $MEDS_READER_DIR \
+   --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
+   --model_name_or_path $CEHR_GPT_MODEL_DIR \
+   --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
+   --output_dir $MEDS_TRAJECTORY_DIR \
+   --per_device_eval_batch_size 16 \
+   --num_of_trajectories_per_sample 2 \
+   --generation_input_length 4096 \
+   --generation_max_new_tokens 4096 \
+   --is_data_in_meds \
+   --att_function_type day --inpatient_att_function_type day \
+   --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
+   --include_auxiliary_token --include_demographic_prompt \
+   --include_inpatient_hour_token
+ ```
+
+ > **Important**: Ensure `generation_input_length` + `generation_max_new_tokens` ≤ `max_position_embeddings` (8192).
+
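A quick sanity check of that budget, using the values from the command above:

```python
generation_input_length = 4096
generation_max_new_tokens = 4096
max_position_embeddings = 8192  # from the MEDS pre-training command
assert generation_input_length + generation_max_new_tokens <= max_position_embeddings
```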
+ #### Parameter Reference
+
+ - `generation_input_length`: Input context length for generation
+ - `generation_max_new_tokens`: Maximum number of new tokens to generate
+ - `num_of_trajectories_per_sample`: Number of trajectories per patient sample
+
+ ## 📖 Citation
+
+ If you use CEHRGPT in your research, please cite:
+
+ ```bibtex
+ @article{cehrgpt2024,
+   title={CEHRGPT: Synthetic Data Generation for Electronic Health Records},
+   author={Natarajan, K and others},
+   journal={arXiv preprint arXiv:2402.04400},
+   year={2024}
+ }
+ ```
+
+ ## 📄 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.