cehrgpt 0.1.1__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/PKG-INFO +95 -1
- cehrgpt-0.1.2/README.md +174 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +57 -33
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +22 -9
- cehrgpt-0.1.2/src/cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/hf_cehrgpt.py +17 -6
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/data_utils.py +17 -6
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +9 -1
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +12 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +20 -30
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/PKG-INFO +95 -1
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/SOURCES.txt +1 -0
- cehrgpt-0.1.1/README.md +0 -80
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.github/workflows/build-python.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.github/workflows/tests.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.gitignore +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.pre-commit-config.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/LICENSE +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/pyproject.toml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/omop_vocab/concept/concept.parquet +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/pretrain/patient_sequence.parquet +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/level_three_evaluation.sh +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/omop_pipeline.sh +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/pool_generated_sequences.sh +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/run_linear_prob.sh +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/setup.cfg +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/irregularity.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/utils.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/cehrgpt_args.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/sample_packing_sampler.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/omop_converter_batch.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/omop_entity.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/gpt_utils.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/config.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/hf_modeling_outputs.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/special_tokens.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/condition_era.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/observation_period.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/omop_argparse.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/omop_table_builder.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/condition_era.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/observation_period.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/gpt_runner_util.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hyperparameter_search_util.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/sample_packing_trainer.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/generate_plots.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/run_simulation.sh +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/time_embedding_simulation.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/time_token_simulation.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/1_year_cabg.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_model.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_prediction.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/requires.txt +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/top_level.txt +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/gpt_utils_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/model_utils_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/runners/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/tokenization_test.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/tools/__init__.py +0 -0
- {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.1.1
|
3
|
+
Version: 0.1.2
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -105,6 +105,100 @@ sh scripts/omop_pipeline.sh \
|
|
105
105
|
$OMOP_VOCAB_DIR
|
106
106
|
```
|
107
107
|
|
108
|
+
# MEDS Support
|
109
|
+
|
110
|
+
This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
|
111
|
+
|
112
|
+
## Prerequisites
|
113
|
+
|
114
|
+
Set up the required environment variables before beginning:
|
115
|
+
|
116
|
+
```bash
|
117
|
+
export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
|
118
|
+
export MEDS_DIR="" # Path to MEDS data directory
|
119
|
+
export MEDS_READER_DIR="" # Path to MEDS reader output directory
|
120
|
+
```
|
121
|
+
|
122
|
+
## Step 1: Create MIMIC MEDS Data
|
123
|
+
|
124
|
+
Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
|
125
|
+
|
126
|
+
## Step 2: Create the MEDS Reader
|
127
|
+
|
128
|
+
Convert the MEDS data for use with CEHR-GPT:
|
129
|
+
|
130
|
+
```bash
|
131
|
+
meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
|
132
|
+
```
|
133
|
+
|
134
|
+
## Step 3: Pretrain CEHR-GPT
|
135
|
+
|
136
|
+
Run the pretraining process using the prepared MEDS data:
|
137
|
+
|
138
|
+
```bash
|
139
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
|
140
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
141
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
142
|
+
--output_dir $CEHR_GPT_MODEL_DIR \
|
143
|
+
--data_folder $MEDS_READER_DIR \
|
144
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
145
|
+
--do_train true --seed 42 \
|
146
|
+
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
|
147
|
+
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
|
148
|
+
--evaluation_strategy epoch --save_strategy epoch \
|
149
|
+
--sample_packing --max_tokens_per_batch 16384 \
|
150
|
+
--warmup_steps 500 --weight_decay 0.01 \
|
151
|
+
--num_train_epochs 50 --learning_rate 0.0002 \
|
152
|
+
--use_early_stopping --early_stopping_threshold 0.001 \
|
153
|
+
--is_data_in_meds --inpatient_att_function_type day \
|
154
|
+
--att_function_type day --include_inpatient_hour_token \
|
155
|
+
--include_auxiliary_token --include_demographic_prompt \
|
156
|
+
--meds_to_cehrbert_conversion_type "MedsToBertMimic4"
|
157
|
+
```
|
158
|
+
|
159
|
+
## Step 4: Generate MEDS Trajectories
|
160
|
+
|
161
|
+
### Environment Setup for Trajectory Generation
|
162
|
+
|
163
|
+
Configure additional environment variables for trajectory generation. The cohort supplies the task labels, given as `subject_id`, `prediction_time`, and (optionally) `boolean_value`:
|
164
|
+
|
165
|
+
```bash
|
166
|
+
# MEDS_LABEL_COHORT_DIR must contain a set of parquet files
|
167
|
+
export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
|
168
|
+
export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
|
169
|
+
```
|
170
|
+
|
171
|
+
### Generate Trajectories
|
172
|
+
|
173
|
+
Create synthetic patient trajectories using the trained model:
|
174
|
+
|
175
|
+
> **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
|
176
|
+
|
177
|
+
```bash
|
178
|
+
python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
|
179
|
+
--cohort_folder $MEDS_LABEL_COHORT_DIR \
|
180
|
+
--data_folder $MEDS_READER_DIR \
|
181
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
182
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
183
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
184
|
+
--output_dir $MEDS_TRAJECTORY_DIR \
|
185
|
+
--per_device_eval_batch_size 16 \
|
186
|
+
--num_of_trajectories_per_sample 2 \
|
187
|
+
--generation_input_length 4096 \
|
188
|
+
--generation_max_new_tokens 4096 \
|
189
|
+
--is_data_in_meds \
|
190
|
+
--att_function_type day --inpatient_att_function_type day \
|
191
|
+
--meds_to_cehrbert_conversion_type MedsToBertMimic4 \
|
192
|
+
--include_auxiliary_token --include_demographic_prompt \
|
193
|
+
--include_inpatient_hour_token
|
194
|
+
```
|
195
|
+
|
196
|
+
### Parameters Explanation
|
197
|
+
|
198
|
+
- `generation_input_length`: Controls the length of input context for generation
|
199
|
+
- `generation_max_new_tokens`: Maximum number of new tokens to generate
|
200
|
+
- `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
|
201
|
+
|
108
202
|
## Citation
|
109
203
|
```
|
110
204
|
@article{cehrgpt2024,
|
cehrgpt-0.1.2/README.md
ADDED
@@ -0,0 +1,174 @@
|
|
1
|
+
# CEHRGPT
|
2
|
+
|
3
|
+
[PyPI package](https://pypi.org/project/cehrgpt/)
|
4
|
+

|
5
|
+
[Tests workflow](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
|
6
|
+
[License](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
|
7
|
+
[Contributors](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
|
8
|
+
|
9
|
+
## Description
|
10
|
+
CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
|
11
|
+
|
12
|
+
## Features
|
13
|
+
- **Synthetic Patient Data Generation**: Generates comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes.
|
14
|
+
- **Privacy-Preserving**: Implements techniques to ensure the generated data does not reveal identifiable information.
|
15
|
+
- **Compatibility with OMOP**: Fully compatible with the OMOP common data model, allowing seamless integration with existing healthcare data systems.
|
16
|
+
- **Extensible**: Designed to be adaptable to new datasets and different EHR systems.
|
17
|
+
|
18
|
+
## Installation
|
19
|
+
To install CEHRGPT, clone this repository and install the required dependencies.
|
20
|
+
|
21
|
+
```bash
|
22
|
+
git clone https://github.com/knatarajan-lab/cehrgpt.git
|
23
|
+
cd cehrgpt
|
24
|
+
pip install .
|
25
|
+
```
|
26
|
+
|
27
|
+
## Pretrain
|
28
|
+
Pretrain CEHRGPT using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file (`sample_configs/cehrgpt_pretrain_sample_config.yaml`).
|
29
|
+
```bash
|
30
|
+
mkdir test_results
|
31
|
+
# This is NOT required when streaming is set to true
|
32
|
+
mkdir test_dataset_prepared
|
33
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
|
34
|
+
```
|
35
|
+
|
36
|
+
## Generate synthetic sequences
|
37
|
+
Generate synthetic sequences using the trained model
|
38
|
+
```bash
|
39
|
+
export TRANSFORMERS_VERBOSITY=info
|
40
|
+
export CUDA_VISIBLE_DEVICES="0"
|
41
|
+
python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
|
42
|
+
--model_folder test_results \
|
43
|
+
--tokenizer_folder test_results \
|
44
|
+
--output_folder test_results \
|
45
|
+
--num_of_patients 128 \
|
46
|
+
--batch_size 32 \
|
47
|
+
--buffer_size 128 \
|
48
|
+
--context_window 1024 \
|
49
|
+
--sampling_strategy TopPStrategy \
|
50
|
+
--top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
|
51
|
+
--epsilon_cutoff 0.00 \
|
52
|
+
--demographic_data_path sample_data/pretrain
|
53
|
+
```
|
54
|
+
|
55
|
+
## Convert synthetic sequences to OMOP
|
56
|
+
```bash
|
57
|
+
# omop converter requires the OHDSI vocabulary
|
58
|
+
export OMOP_VOCAB_DIR=""
|
59
|
+
# the omop derived tables need to be built using pyspark
|
60
|
+
export SPARK_WORKER_INSTANCES="1"
|
61
|
+
export SPARK_WORKER_CORES="8"
|
62
|
+
export SPARK_EXECUTOR_CORES="2"
|
63
|
+
export SPARK_DRIVER_MEMORY="2g"
|
64
|
+
export SPARK_EXECUTOR_MEMORY="2g"
|
65
|
+
|
66
|
+
# Convert the sequences, create the omop derived tables
|
67
|
+
sh scripts/omop_pipeline.sh \
|
68
|
+
test_results/top_p10000/generated_sequences/ \
|
69
|
+
test_results/top_p10000/restored_omop/ \
|
70
|
+
$OMOP_VOCAB_DIR
|
71
|
+
```
|
72
|
+
|
73
|
+
# MEDS Support
|
74
|
+
|
75
|
+
This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
|
76
|
+
|
77
|
+
## Prerequisites
|
78
|
+
|
79
|
+
Set up the required environment variables before beginning:
|
80
|
+
|
81
|
+
```bash
|
82
|
+
export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
|
83
|
+
export MEDS_DIR="" # Path to MEDS data directory
|
84
|
+
export MEDS_READER_DIR="" # Path to MEDS reader output directory
|
85
|
+
```
|
86
|
+
|
87
|
+
## Step 1: Create MIMIC MEDS Data
|
88
|
+
|
89
|
+
Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
|
90
|
+
|
91
|
+
## Step 2: Create the MEDS Reader
|
92
|
+
|
93
|
+
Convert the MEDS data for use with CEHR-GPT:
|
94
|
+
|
95
|
+
```bash
|
96
|
+
meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
|
97
|
+
```
|
98
|
+
|
99
|
+
## Step 3: Pretrain CEHR-GPT
|
100
|
+
|
101
|
+
Run the pretraining process using the prepared MEDS data:
|
102
|
+
|
103
|
+
```bash
|
104
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
|
105
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
106
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
107
|
+
--output_dir $CEHR_GPT_MODEL_DIR \
|
108
|
+
--data_folder $MEDS_READER_DIR \
|
109
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
110
|
+
--do_train true --seed 42 \
|
111
|
+
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
|
112
|
+
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
|
113
|
+
--evaluation_strategy epoch --save_strategy epoch \
|
114
|
+
--sample_packing --max_tokens_per_batch 16384 \
|
115
|
+
--warmup_steps 500 --weight_decay 0.01 \
|
116
|
+
--num_train_epochs 50 --learning_rate 0.0002 \
|
117
|
+
--use_early_stopping --early_stopping_threshold 0.001 \
|
118
|
+
--is_data_in_meds --inpatient_att_function_type day \
|
119
|
+
--att_function_type day --include_inpatient_hour_token \
|
120
|
+
--include_auxiliary_token --include_demographic_prompt \
|
121
|
+
--meds_to_cehrbert_conversion_type "MedsToBertMimic4"
|
122
|
+
```
|
123
|
+
|
124
|
+
## Step 4: Generate MEDS Trajectories
|
125
|
+
|
126
|
+
### Environment Setup for Trajectory Generation
|
127
|
+
|
128
|
+
Configure additional environment variables for trajectory generation. The cohort supplies the task labels, given as `subject_id`, `prediction_time`, and (optionally) `boolean_value`:
|
129
|
+
|
130
|
+
```bash
|
131
|
+
# MEDS_LABEL_COHORT_DIR must contain a set of parquet files
|
132
|
+
export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
|
133
|
+
export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
|
134
|
+
```
|
135
|
+
|
136
|
+
### Generate Trajectories
|
137
|
+
|
138
|
+
Create synthetic patient trajectories using the trained model:
|
139
|
+
|
140
|
+
> **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
|
141
|
+
|
142
|
+
```bash
|
143
|
+
python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
|
144
|
+
--cohort_folder $MEDS_LABEL_COHORT_DIR \
|
145
|
+
--data_folder $MEDS_READER_DIR \
|
146
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
147
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
148
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
149
|
+
--output_dir $MEDS_TRAJECTORY_DIR \
|
150
|
+
--per_device_eval_batch_size 16 \
|
151
|
+
--num_of_trajectories_per_sample 2 \
|
152
|
+
--generation_input_length 4096 \
|
153
|
+
--generation_max_new_tokens 4096 \
|
154
|
+
--is_data_in_meds \
|
155
|
+
--att_function_type day --inpatient_att_function_type day \
|
156
|
+
--meds_to_cehrbert_conversion_type MedsToBertMimic4 \
|
157
|
+
--include_auxiliary_token --include_demographic_prompt \
|
158
|
+
--include_inpatient_hour_token
|
159
|
+
```
|
160
|
+
|
161
|
+
### Parameters Explanation
|
162
|
+
|
163
|
+
- `generation_input_length`: Controls the length of input context for generation
|
164
|
+
- `generation_max_new_tokens`: Maximum number of new tokens to generate
|
165
|
+
- `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
|
166
|
+
|
167
|
+
## Citation
|
168
|
+
```
|
169
|
+
@article{cehrgpt2024,
|
170
|
+
title={CEHRGPT: Synthetic Data Generation for Electronic Health Records},
|
171
|
+
author={Natarajan, K and others},
|
172
|
+
journal={arXiv preprint arXiv:2402.04400},
|
173
|
+
year={2024}
|
174
|
+
}
|
@@ -162,6 +162,22 @@ class CehrGptDataCollator:
|
|
162
162
|
f"batch['input_ids']: {batch['input_ids']} "
|
163
163
|
)
|
164
164
|
|
165
|
+
if "epoch_times" in examples[0]:
|
166
|
+
batch_epoch_times = [
|
167
|
+
self._try_reverse_tensor(
|
168
|
+
self._convert_to_tensor(example["epoch_times"])
|
169
|
+
)
|
170
|
+
for example in examples
|
171
|
+
]
|
172
|
+
# Pad sequences to the max length in the batch
|
173
|
+
batch["epoch_times"] = self._try_reverse_tensor(
|
174
|
+
pad_sequence(
|
175
|
+
batch_epoch_times,
|
176
|
+
batch_first=True,
|
177
|
+
padding_value=0,
|
178
|
+
).to(torch.float32)
|
179
|
+
)
|
180
|
+
|
165
181
|
if "position_ids" in examples[0]:
|
166
182
|
batch_position_ids = [
|
167
183
|
self._try_reverse_tensor(
|
@@ -663,7 +679,9 @@ class CehrGptDataCollator:
|
|
663
679
|
|
664
680
|
# Subtract one for the [END] token when sample_packing is not enabled
|
665
681
|
new_max_length = (
|
666
|
-
max_length_allowed
|
682
|
+
max_length_allowed - 1
|
683
|
+
if not sample_packing and self.pretraining
|
684
|
+
else max_length_allowed
|
667
685
|
)
|
668
686
|
|
669
687
|
if self.include_ttv_prediction:
|
@@ -685,13 +703,20 @@ class CehrGptDataCollator:
|
|
685
703
|
|
686
704
|
# Return the record directly if the actual sequence length is less than the max sequence
|
687
705
|
if seq_length <= new_max_length:
|
688
|
-
if not sample_packing:
|
706
|
+
if not sample_packing and self.pretraining:
|
689
707
|
record["input_ids"] = torch.concat(
|
690
708
|
[
|
691
709
|
self._convert_to_tensor(record["input_ids"]),
|
692
710
|
self._convert_to_tensor([eos_token]),
|
693
711
|
]
|
694
712
|
)
|
713
|
+
if "epoch_times" in record:
|
714
|
+
record["epoch_times"] = torch.concat(
|
715
|
+
[
|
716
|
+
self._convert_to_tensor(record["epoch_times"]),
|
717
|
+
self._convert_to_tensor([record["epoch_times"][-1]]),
|
718
|
+
]
|
719
|
+
)
|
695
720
|
if self.include_values:
|
696
721
|
record["value_indicators"] = torch.concat(
|
697
722
|
[
|
@@ -727,6 +752,10 @@ class CehrGptDataCollator:
|
|
727
752
|
record["input_ids"] = self._convert_to_tensor(
|
728
753
|
record["input_ids"][start_index : end_index + 1]
|
729
754
|
)
|
755
|
+
if "epoch_times" in record:
|
756
|
+
record["epoch_times"] = self._convert_to_tensor(
|
757
|
+
record["epoch_times"][start_index : end_index + 1]
|
758
|
+
)
|
730
759
|
if self.include_values:
|
731
760
|
record["value_indicators"] = self._convert_to_tensor(
|
732
761
|
record["value_indicators"][start_index : end_index + 1]
|
@@ -760,6 +789,11 @@ class CehrGptDataCollator:
|
|
760
789
|
if sample_packing and "position_ids" in record:
|
761
790
|
record["position_ids"] = record["position_ids"][0:end_index]
|
762
791
|
|
792
|
+
if "epoch_times" in record:
|
793
|
+
record["epoch_times"] = self._convert_to_tensor(
|
794
|
+
record["epoch_times"][0:end_index]
|
795
|
+
)
|
796
|
+
|
763
797
|
if self.include_values:
|
764
798
|
record["value_indicators"] = self._convert_to_tensor(
|
765
799
|
record["value_indicators"][0:end_index]
|
@@ -792,6 +826,17 @@ class CehrGptDataCollator:
|
|
792
826
|
),
|
793
827
|
]
|
794
828
|
)
|
829
|
+
if "epoch_times" in record:
|
830
|
+
record["epoch_times"] = torch.concat(
|
831
|
+
[
|
832
|
+
torch.zeros(
|
833
|
+
[record["epoch_times"][0]], dtype=torch.float32
|
834
|
+
),
|
835
|
+
self._convert_to_tensor(
|
836
|
+
record["epoch_times"][token_index:seq_length]
|
837
|
+
),
|
838
|
+
]
|
839
|
+
)
|
795
840
|
if self.include_values:
|
796
841
|
record["value_indicators"] = torch.concat(
|
797
842
|
[
|
@@ -830,7 +875,7 @@ class CehrGptDataCollator:
|
|
830
875
|
)
|
831
876
|
break
|
832
877
|
else:
|
833
|
-
start_index = seq_length - new_max_length
|
878
|
+
start_index = max(seq_length - new_max_length, 0)
|
834
879
|
end_index = seq_length
|
835
880
|
for i in range(start_index, end_index):
|
836
881
|
current_token = record["input_ids"][i]
|
@@ -842,6 +887,11 @@ class CehrGptDataCollator:
|
|
842
887
|
]
|
843
888
|
if sample_packing and "position_ids" in record:
|
844
889
|
record["position_ids"] = record["position_ids"][i:end_index]
|
890
|
+
|
891
|
+
if "epoch_times" in record:
|
892
|
+
record["epoch_times"] = self._convert_to_tensor(
|
893
|
+
record["epoch_times"][i:end_index]
|
894
|
+
)
|
845
895
|
if self.include_values:
|
846
896
|
record["value_indicators"] = record["value_indicators"][
|
847
897
|
i:end_index
|
@@ -863,6 +913,10 @@ class CehrGptDataCollator:
|
|
863
913
|
]
|
864
914
|
if sample_packing and "position_ids" in record:
|
865
915
|
record["position_ids"] = record["position_ids"][-new_max_length:]
|
916
|
+
if "epoch_times" in record:
|
917
|
+
record["epoch_times"] = self._convert_to_tensor(
|
918
|
+
record["epoch_times"][-new_max_length:]
|
919
|
+
)
|
866
920
|
if self.include_values:
|
867
921
|
record["value_indicators"] = record["value_indicators"][
|
868
922
|
-new_max_length:
|
@@ -873,36 +927,6 @@ class CehrGptDataCollator:
|
|
873
927
|
-new_max_length:
|
874
928
|
]
|
875
929
|
|
876
|
-
if not sample_packing:
|
877
|
-
# Finally we add the end token to the end of the sequence
|
878
|
-
record["input_ids"] = torch.concat(
|
879
|
-
[
|
880
|
-
self._convert_to_tensor(record["input_ids"]),
|
881
|
-
self._convert_to_tensor([eos_token]),
|
882
|
-
]
|
883
|
-
)
|
884
|
-
if self.include_values:
|
885
|
-
record["value_indicators"] = torch.concat(
|
886
|
-
[
|
887
|
-
self._convert_to_tensor(record["value_indicators"]),
|
888
|
-
self._convert_to_tensor([False]),
|
889
|
-
]
|
890
|
-
).to(torch.bool)
|
891
|
-
record["values"] = torch.concat(
|
892
|
-
[
|
893
|
-
self._convert_to_tensor(record["values"]),
|
894
|
-
self._convert_to_tensor(
|
895
|
-
[self.tokenizer.pad_value_token_id]
|
896
|
-
),
|
897
|
-
]
|
898
|
-
)
|
899
|
-
if self.include_ttv_prediction:
|
900
|
-
record["time_to_visits"] = torch.concat(
|
901
|
-
[
|
902
|
-
record["time_to_visits"],
|
903
|
-
self._convert_to_tensor([-100.0]),
|
904
|
-
]
|
905
|
-
)
|
906
930
|
return record
|
907
931
|
|
908
932
|
|
@@ -21,7 +21,6 @@ from cehrbert_data.const.artificial_tokens import (
|
|
21
21
|
DISCHARGE_UNKNOWN_TOKEN,
|
22
22
|
GENDER_UNKNOWN_TOKEN,
|
23
23
|
RACE_UNKNOWN_TOKEN,
|
24
|
-
VISIT_UNKNOWN_TOKEN,
|
25
24
|
)
|
26
25
|
from cehrbert_data.const.common import NA
|
27
26
|
from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
|
@@ -47,10 +46,16 @@ CEHRGPT_COLUMNS = [
|
|
47
46
|
]
|
48
47
|
|
49
48
|
|
50
|
-
def convert_date_to_posix_time(index_date: datetime.date) -> float:
|
51
|
-
|
52
|
-
|
53
|
-
|
49
|
+
def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
|
50
|
+
if isinstance(index_date, datetime.date):
|
51
|
+
return (
|
52
|
+
datetime.datetime.combine(index_date, datetime.datetime.min.time())
|
53
|
+
.replace(tzinfo=datetime.timezone.utc)
|
54
|
+
.timestamp()
|
55
|
+
)
|
56
|
+
elif isinstance(index_date, datetime.datetime):
|
57
|
+
return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
|
58
|
+
return index_date
|
54
59
|
|
55
60
|
|
56
61
|
class DatasetMappingDecorator(DatasetMapping):
|
@@ -128,7 +133,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
128
133
|
cehrgpt_record["concept_as_values"].append(concept_as_value)
|
129
134
|
cehrgpt_record["units"].append(unit)
|
130
135
|
cehrgpt_record["is_numeric_types"].append(is_numeric_type)
|
131
|
-
cehrgpt_record["epoch_times"].append(
|
136
|
+
cehrgpt_record["epoch_times"].append(
|
137
|
+
time.replace(tzinfo=datetime.timezone.utc).timestamp()
|
138
|
+
)
|
132
139
|
|
133
140
|
def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
|
134
141
|
cehrgpt_record = {
|
@@ -360,7 +367,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
360
367
|
cehrgpt_record["num_of_visits"] = len(visits)
|
361
368
|
|
362
369
|
if record.get("index_date", None) is not None:
|
363
|
-
cehrgpt_record["index_date"] =
|
370
|
+
cehrgpt_record["index_date"] = (
|
371
|
+
record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
|
372
|
+
)
|
364
373
|
if record.get("label", None) is not None:
|
365
374
|
cehrgpt_record["label"] = record["label"]
|
366
375
|
if record.get("age_at_index", None) is not None:
|
@@ -529,9 +538,13 @@ class ExtractTokenizedSequenceDataMapping:
|
|
529
538
|
prediction_start_end_times = [
|
530
539
|
(
|
531
540
|
self._calculate_prediction_start_time(
|
532
|
-
prediction_time_label_map["index_date"]
|
541
|
+
prediction_time_label_map["index_date"]
|
542
|
+
.replace(tzinfo=datetime.timezone.utc)
|
543
|
+
.timestamp()
|
533
544
|
),
|
534
|
-
prediction_time_label_map["index_date"]
|
545
|
+
prediction_time_label_map["index_date"]
|
546
|
+
.replace(tzinfo=datetime.timezone.utc)
|
547
|
+
.timestamp(),
|
535
548
|
prediction_time_label_map["label"],
|
536
549
|
)
|
537
550
|
for prediction_time_label_map in prediction_times
|