cehrgpt 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/PKG-INFO +102 -7
  2. cehrgpt-0.1.2/README.md +174 -0
  3. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/pyproject.toml +4 -3
  4. cehrgpt-0.1.2/sample_data/omop_vocab/concept/concept.parquet +0 -0
  5. cehrgpt-0.1.2/scripts/run_linear_prob.sh +260 -0
  6. cehrgpt-0.1.2/src/cehrgpt/analysis/irregularity.py +36 -0
  7. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  8. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  9. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  10. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/data/sample_packing_sampler.py +36 -6
  11. cehrgpt-0.1.2/src/cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  12. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  13. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/generation/omop_converter_batch.py +32 -2
  14. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/gpt_utils.py +20 -2
  15. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/config.py +25 -0
  16. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/hf_cehrgpt.py +244 -39
  17. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/hf_modeling_outputs.py +1 -0
  18. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/special_tokens.py +1 -0
  19. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  20. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/data_utils.py +131 -5
  21. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  22. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  23. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  24. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/hyperparameter_search_util.py +6 -7
  25. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/sample_packing_trainer.py +17 -0
  26. cehrgpt-0.1.2/src/cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  27. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_model.py +2 -13
  28. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  29. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  30. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/PKG-INFO +102 -7
  31. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/SOURCES.txt +5 -0
  32. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/requires.txt +4 -3
  33. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +15 -6
  34. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +7 -0
  35. cehrgpt-0.1.0/README.md +0 -80
  36. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/.github/workflows/build-python.yaml +0 -0
  37. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/.github/workflows/tests.yaml +0 -0
  38. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/.gitignore +0 -0
  39. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/.pre-commit-config.yaml +0 -0
  40. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/LICENSE +0 -0
  41. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
  42. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/sample_data/pretrain/patient_sequence.parquet +0 -0
  43. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
  44. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
  45. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/scripts/level_three_evaluation.sh +0 -0
  46. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/scripts/omop_pipeline.sh +0 -0
  47. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/scripts/pool_generated_sequences.sh +0 -0
  48. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/setup.cfg +0 -0
  49. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/__init__.py +0 -0
  50. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/__init__.py +0 -0
  51. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/__init__.py +0 -0
  52. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
  53. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
  54. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
  55. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
  56. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
  57. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
  58. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/utils.py +0 -0
  59. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/cehrgpt_args.py +0 -0
  60. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/data/__init__.py +0 -0
  61. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/generation/__init__.py +0 -0
  62. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
  63. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/generation/omop_entity.py +0 -0
  64. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/__init__.py +0 -0
  65. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
  66. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/__init__.py +0 -0
  67. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/condition_era.py +0 -0
  68. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/observation_period.py +0 -0
  69. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/omop_argparse.py +0 -0
  70. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/omop_table_builder.py +0 -0
  71. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/__init__.py +0 -0
  72. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/condition_era.py +0 -0
  73. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/observation_period.py +0 -0
  74. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
  75. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/__init__.py +0 -0
  76. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/runners/gpt_runner_util.py +0 -0
  77. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/simulations/__init__.py +0 -0
  78. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/simulations/generate_plots.py +0 -0
  79. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/simulations/run_simulation.sh +0 -0
  80. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/simulations/time_embedding_simulation.py +0 -0
  81. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/simulations/time_token_simulation.py +0 -0
  82. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/__init__.py +0 -0
  83. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
  84. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
  85. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
  86. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
  87. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/__init__.py +0 -0
  88. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
  89. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
  90. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
  91. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/__init__.py +0 -0
  92. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +0 -0
  93. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
  94. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
  95. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
  96. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/top_level.txt +0 -0
  97. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/__init__.py +0 -0
  98. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/integration_tests/__init__.py +0 -0
  99. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/integration_tests/runners/__init__.py +0 -0
  100. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +0 -0
  101. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/__init__.py +0 -0
  102. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/gpt_utils_test.py +0 -0
  103. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/models/__init__.py +0 -0
  104. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/models/model_utils_test.py +0 -0
  105. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/__init__.py +0 -0
  106. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
  107. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
  108. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
  109. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/runners/__init__.py +0 -0
  110. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +0 -0
  111. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/tokenization_test.py +0 -0
  112. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/tools/__init__.py +0 -0
  113. {cehrgpt-0.1.0 → cehrgpt-0.1.2}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cehrgpt
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -12,14 +12,15 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10.0
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: cehrbert==1.4.1
16
- Requires-Dist: cehrbert_data==0.0.7
15
+ Requires-Dist: cehrbert==1.4.5
16
+ Requires-Dist: cehrbert_data==0.0.11
17
17
  Requires-Dist: openai==1.54.3
18
18
  Requires-Dist: optuna==4.0.0
19
- Requires-Dist: transformers==4.44.0
19
+ Requires-Dist: transformers==4.44.1
20
20
  Requires-Dist: tokenizers==0.19.0
21
21
  Requires-Dist: peft==0.10.0
22
22
  Requires-Dist: lightgbm
23
+ Requires-Dist: polars
23
24
  Provides-Extra: dev
24
25
  Requires-Dist: pre-commit; extra == "dev"
25
26
  Requires-Dist: pytest; extra == "dev"
@@ -36,9 +37,9 @@ Dynamic: license-file
36
37
 
37
38
  [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
38
39
  ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
39
- [![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
40
- [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
41
- [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
40
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
41
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
42
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
42
43
 
43
44
  ## Description
44
45
  CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
@@ -104,6 +105,100 @@ sh scripts/omop_pipeline.sh \
104
105
  $OMOP_VOCAB_DIR
105
106
  ```
106
107
 
108
+ # MEDS Support
109
+
110
+ This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
111
+
112
+ ## Prerequisites
113
+
114
+ Set up the required environment variables before beginning:
115
+
116
+ ```bash
117
+ export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
118
+ export MEDS_DIR="" # Path to MEDS data directory
119
+ export MEDS_READER_DIR="" # Path to MEDS reader output directory
120
+ ```
121
+
122
+ ## Step 1: Create MIMIC MEDS Data
123
+
124
+ Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
125
+
126
+ ## Step 2: Create the MEDS Reader
127
+
128
+ Convert the MEDS data for use with CEHR-GPT:
129
+
130
+ ```bash
131
+ meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
132
+ ```
133
+
134
+ ## Step 3: Pretrain CEHR-GPT
135
+
136
+ Run the pretraining process using the prepared MEDS data:
137
+
138
+ ```bash
139
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
140
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
141
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
142
+ --output_dir $CEHR_GPT_MODEL_DIR \
143
+ --data_folder $MEDS_READER_DIR \
144
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
145
+ --do_train true --seed 42 \
146
+ --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
147
+ --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
148
+ --evaluation_strategy epoch --save_strategy epoch \
149
+ --sample_packing --max_tokens_per_batch 16384 \
150
+ --warmup_steps 500 --weight_decay 0.01 \
151
+ --num_train_epochs 50 --learning_rate 0.0002 \
152
+ --use_early_stopping --early_stopping_threshold 0.001 \
153
+ --is_data_in_meds --inpatient_att_function_type day \
154
+ --att_function_type day --include_inpatient_hour_token \
155
+ --include_auxiliary_token --include_demographic_prompt \
156
+ --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
157
+ ```
158
+
159
+ ## Step 4: Generate MEDS Trajectories
160
+
161
+ ### Environment Setup for Trajectory Generation
162
+
163
+ Configure additional environment variables for trajectory generation with task labels (`subject_id`, `prediction_time`, `boolean_value` [optional]):
164
+
165
+ ```bash
166
+ # MEDS_LABEL_COHORT_DIR must contain a set of parquet files
167
+ export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
168
+ export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
169
+ ```
170
+
171
+ ### Generate Trajectories
172
+
173
+ Create synthetic patient trajectories using the trained model:
174
+
175
+ > **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
176
+
177
+ ```bash
178
+ python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
179
+ --cohort_folder $MEDS_LABEL_COHORT_DIR \
180
+ --data_folder $MEDS_READER_DIR \
181
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
182
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
183
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
184
+ --output_dir $MEDS_TRAJECTORY_DIR \
185
+ --per_device_eval_batch_size 16 \
186
+ --num_of_trajectories_per_sample 2 \
187
+ --generation_input_length 4096 \
188
+ --generation_max_new_tokens 4096 \
189
+ --is_data_in_meds \
190
+ --att_function_type day --inpatient_att_function_type day \
191
+ --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
192
+ --include_auxiliary_token --include_demographic_prompt \
193
+ --include_inpatient_hour_token
194
+ ```
195
+
196
+ ### Parameters Explanation
197
+
198
+ - `generation_input_length`: Controls the length of input context for generation
199
+ - `generation_max_new_tokens`: Maximum number of new tokens to generate
200
+ - `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
201
+
107
202
  ## Citation
108
203
  ```
109
204
  @article{cehrgpt2024,
@@ -0,0 +1,174 @@
1
+ # CEHRGPT
2
+
3
+ [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
4
+ ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
5
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
6
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
7
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
8
+
9
+ ## Description
10
+ CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
11
+
12
+ ## Features
13
+ - **Synthetic Patient Data Generation**: Generates comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes.
14
+ - **Privacy-Preserving**: Implements techniques to ensure the generated data does not reveal identifiable information.
15
+ - **Compatibility with OMOP**: Fully compatible with the OMOP common data model, allowing seamless integration with existing healthcare data systems.
16
+ - **Extensible**: Designed to be adaptable to new datasets and different EHR systems.
17
+
18
+ ## Installation
19
+ To install CEHRGPT, clone this repository and install the required dependencies.
20
+
21
+ ```bash
22
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
23
+ cd cehrgpt
24
+ pip install .
25
+ ```
26
+
27
+ ## Pretrain
28
+ Pretrain CEHRGPT using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file.
29
+ ```bash
30
+ mkdir test_results
31
+ # This is NOT required when streaming is set to true
32
+ mkdir test_dataset_prepared
33
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
34
+ ```
35
+
36
+ ## Generate synthetic sequences
37
+ Generate synthetic sequences using the trained model
38
+ ```bash
39
+ export TRANSFORMERS_VERBOSITY=info
40
+ export CUDA_VISIBLE_DEVICES="0"
41
+ python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
42
+ --model_folder test_results \
43
+ --tokenizer_folder test_results \
44
+ --output_folder test_results \
45
+ --num_of_patients 128 \
46
+ --batch_size 32 \
47
+ --buffer_size 128 \
48
+ --context_window 1024 \
49
+ --sampling_strategy TopPStrategy \
50
+ --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
51
+ --epsilon_cutoff 0.00 \
52
+ --demographic_data_path sample_data/pretrain
53
+ ```
54
+
55
+ ## Convert synthetic sequences to OMOP
56
+ ```bash
57
+ # omop converter requires the OHDSI vocabulary
58
+ export OMOP_VOCAB_DIR=""
59
+ # the omop derived tables need to be built using pyspark
60
+ export SPARK_WORKER_INSTANCES="1"
61
+ export SPARK_WORKER_CORES="8"
62
+ export SPARK_EXECUTOR_CORES="2"
63
+ export SPARK_DRIVER_MEMORY="2g"
64
+ export SPARK_EXECUTOR_MEMORY="2g"
65
+
66
+ # Convert the sequences, create the omop derived tables
67
+ sh scripts/omop_pipeline.sh \
68
+ test_results/top_p10000/generated_sequences/ \
69
+ test_results/top_p10000/restored_omop/ \
70
+ $OMOP_VOCAB_DIR
71
+ ```
72
+
73
+ # MEDS Support
74
+
75
+ This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
76
+
77
+ ## Prerequisites
78
+
79
+ Set up the required environment variables before beginning:
80
+
81
+ ```bash
82
+ export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
83
+ export MEDS_DIR="" # Path to MEDS data directory
84
+ export MEDS_READER_DIR="" # Path to MEDS reader output directory
85
+ ```
86
+
87
+ ## Step 1: Create MIMIC MEDS Data
88
+
89
+ Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
90
+
91
+ ## Step 2: Create the MEDS Reader
92
+
93
+ Convert the MEDS data for use with CEHR-GPT:
94
+
95
+ ```bash
96
+ meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
97
+ ```
98
+
99
+ ## Step 3: Pretrain CEHR-GPT
100
+
101
+ Run the pretraining process using the prepared MEDS data:
102
+
103
+ ```bash
104
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
105
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
106
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
107
+ --output_dir $CEHR_GPT_MODEL_DIR \
108
+ --data_folder $MEDS_READER_DIR \
109
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
110
+ --do_train true --seed 42 \
111
+ --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
112
+ --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
113
+ --evaluation_strategy epoch --save_strategy epoch \
114
+ --sample_packing --max_tokens_per_batch 16384 \
115
+ --warmup_steps 500 --weight_decay 0.01 \
116
+ --num_train_epochs 50 --learning_rate 0.0002 \
117
+ --use_early_stopping --early_stopping_threshold 0.001 \
118
+ --is_data_in_meds --inpatient_att_function_type day \
119
+ --att_function_type day --include_inpatient_hour_token \
120
+ --include_auxiliary_token --include_demographic_prompt \
121
+ --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
122
+ ```
123
+
124
+ ## Step 4: Generate MEDS Trajectories
125
+
126
+ ### Environment Setup for Trajectory Generation
127
+
128
+ Configure additional environment variables for trajectory generation with task labels (`subject_id`, `prediction_time`, `boolean_value` [optional]):
129
+
130
+ ```bash
131
+ # MEDS_LABEL_COHORT_DIR must contain a set of parquet files
132
+ export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
133
+ export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
134
+ ```
135
+
136
+ ### Generate Trajectories
137
+
138
+ Create synthetic patient trajectories using the trained model:
139
+
140
+ > **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
141
+
142
+ ```bash
143
+ python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
144
+ --cohort_folder $MEDS_LABEL_COHORT_DIR \
145
+ --data_folder $MEDS_READER_DIR \
146
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
147
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
148
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
149
+ --output_dir $MEDS_TRAJECTORY_DIR \
150
+ --per_device_eval_batch_size 16 \
151
+ --num_of_trajectories_per_sample 2 \
152
+ --generation_input_length 4096 \
153
+ --generation_max_new_tokens 4096 \
154
+ --is_data_in_meds \
155
+ --att_function_type day --inpatient_att_function_type day \
156
+ --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
157
+ --include_auxiliary_token --include_demographic_prompt \
158
+ --include_inpatient_hour_token
159
+ ```
160
+
161
+ ### Parameters Explanation
162
+
163
+ - `generation_input_length`: Controls the length of input context for generation
164
+ - `generation_max_new_tokens`: Maximum number of new tokens to generate
165
+ - `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
166
+
167
+ ## Citation
168
+ ```
169
+ @article{cehrgpt2024,
170
+ title={CEHRGPT: Synthetic Data Generation for Electronic Health Records},
171
+ author={Natarajan, K and others},
172
+ journal={arXiv preprint arXiv:2402.04400},
173
+ year={2024}
174
+ }
@@ -28,14 +28,15 @@ classifiers = [
28
28
  ]
29
29
 
30
30
  dependencies = [
31
- "cehrbert==1.4.1",
32
- "cehrbert_data==0.0.7",
31
+ "cehrbert==1.4.5",
32
+ "cehrbert_data==0.0.11",
33
33
  "openai==1.54.3",
34
34
  "optuna==4.0.0",
35
- "transformers==4.44.0",
35
+ "transformers==4.44.1",
36
36
  "tokenizers==0.19.0",
37
37
  "peft==0.10.0",
38
38
  "lightgbm",
39
+ "polars",
39
40
  ]
40
41
 
41
42
  [tool.setuptools_scm]
@@ -0,0 +1,260 @@
1
+ #!/bin/sh
2
+
3
+ # Print usage information and exit with status 1; the defaults listed here must match the "Default values" section
4
+ usage() {
5
+ echo "Usage: $0 [options]"
6
+ echo ""
7
+ echo "Options:"
8
+ echo " --base_dir=DIR Base directory containing cohorts (required)"
9
+ echo " --dataset_prepared_path=PATH Path to prepared dataset (required)"
10
+ echo " --model_path=PATH Path to pre-trained model and tokenizer (required)"
11
+ echo " --preprocessing_workers=NUM Number of preprocessing workers (required)"
12
+ echo " --batch_size=NUM Batch size for evaluation (required)"
13
+ echo " --output_dir=DIR Output directory for results (required)"
14
+ echo " --model_name=NAME Name for the model output directory (default: cehrgpt_model)"
15
+ echo " --max_tokens_per_batch=NUM Maximum tokens per batch (default: 16384)"
16
+ echo " --torch_type=TYPE Torch data type (default: bfloat16)"
17
+ echo " --disable_sample_packing Disable sample packing (enabled by default)"
18
+ echo ""
19
+ echo "Example:"
20
+ echo " $0 --base_dir=/path/to/cohorts --dataset_prepared_path=/path/to/dataset_prepared \\"
21
+ echo " --model_path=/path/to/model --preprocessing_workers=16 --batch_size=64 \\"
22
+ echo " --output_dir=/path/to/outputs --model_name=my_model --torch_type=float16"
23
+ exit 1
24
+ }
25
+
26
+ # Default values
27
+ MODEL_NAME="cehrgpt_model"
28
+ MAX_TOKENS_PER_BATCH="16384"
29
+ TORCH_TYPE="bfloat16"
30
+ DISABLE_SAMPLE_PACKING="false"
31
+
32
+ # Parse command line arguments
33
+ for arg in "$@"; do
34
+ case $arg in
35
+ --base_dir=*)
36
+ BASE_DIR="${arg#*=}"
37
+ ;;
38
+ --dataset_prepared_path=*)
39
+ DATASET_PREPARED_PATH="${arg#*=}"
40
+ ;;
41
+ --model_path=*)
42
+ MODEL_PATH="${arg#*=}"
43
+ ;;
44
+ --preprocessing_workers=*)
45
+ PREPROCESSING_WORKERS="${arg#*=}"
46
+ ;;
47
+ --batch_size=*)
48
+ BATCH_SIZE="${arg#*=}"
49
+ ;;
50
+ --output_dir=*)
51
+ OUTPUT_DIR="${arg#*=}"
52
+ ;;
53
+ --model_name=*)
54
+ MODEL_NAME="${arg#*=}"
55
+ ;;
56
+ --max_tokens_per_batch=*)
57
+ MAX_TOKENS_PER_BATCH="${arg#*=}"
58
+ ;;
59
+ --torch_type=*)
60
+ TORCH_TYPE="${arg#*=}"
61
+ ;;
62
+ --disable_sample_packing)
63
+ DISABLE_SAMPLE_PACKING="true"
64
+ ;;
65
+ --help|-h)
66
+ usage
67
+ ;;
68
+ *)
69
+ echo "Error: Unknown option: $arg"
70
+ usage
71
+ ;;
72
+ esac
73
+ done
74
+
75
+ # Check for required arguments
76
+ if [ -z "$BASE_DIR" ] || [ -z "$DATASET_PREPARED_PATH" ] || [ -z "$MODEL_PATH" ] || [ -z "$PREPROCESSING_WORKERS" ] || [ -z "$BATCH_SIZE" ] || [ -z "$OUTPUT_DIR" ]; then
77
+ echo "Error: Missing required arguments"
78
+ usage
79
+ fi
80
+
81
+ # Validate arguments
82
+ if [ ! -d "$BASE_DIR" ]; then
83
+ echo "Error: Base directory does not exist: $BASE_DIR"
84
+ exit 1
85
+ fi
86
+
87
+ if [ ! -d "$DATASET_PREPARED_PATH" ]; then
88
+ echo "Error: Dataset prepared path does not exist: $DATASET_PREPARED_PATH"
89
+ exit 1
90
+ fi
91
+
92
+ if [ ! -d "$MODEL_PATH" ]; then
93
+ echo "Error: Model path does not exist: $MODEL_PATH"
94
+ exit 1
95
+ fi
96
+
97
+ # Create output directory if it doesn't exist
98
+ mkdir -p "$OUTPUT_DIR"
99
+
100
+ # Check if preprocessing workers is a number
101
+ if ! [ "$PREPROCESSING_WORKERS" -eq "$PREPROCESSING_WORKERS" ] 2>/dev/null; then
102
+ echo "Error: Preprocessing workers must be a number: $PREPROCESSING_WORKERS"
103
+ exit 1
104
+ fi
105
+
106
+ # Check if batch size is a number
107
+ if ! [ "$BATCH_SIZE" -eq "$BATCH_SIZE" ] 2>/dev/null; then
108
+ echo "Error: Batch size must be a number: $BATCH_SIZE"
109
+ exit 1
110
+ fi
111
+
112
+ # Check if max tokens per batch is a number
113
+ if ! [ "$MAX_TOKENS_PER_BATCH" -eq "$MAX_TOKENS_PER_BATCH" ] 2>/dev/null; then
114
+ echo "Error: Max tokens per batch must be a number: $MAX_TOKENS_PER_BATCH"
115
+ exit 1
116
+ fi
117
+
118
+ # Validate torch_type (common PyTorch data types)
119
+ case "$TORCH_TYPE" in
120
+ float16|float32|float64|bfloat16|int8|int16|int32|int64)
121
+ ;;
122
+ *)
123
+ echo "Error: Invalid torch_type. Supported types: float16, float32, float64, bfloat16, int8, int16, int32, int64"
124
+ exit 1
125
+ ;;
126
+ esac
127
+
128
+ # Validate disable_sample_packing is boolean-like
129
+ if [ "$DISABLE_SAMPLE_PACKING" != "true" ] && [ "$DISABLE_SAMPLE_PACKING" != "false" ]; then
130
+ echo "Error: disable_sample_packing must be 'true' or 'false': $DISABLE_SAMPLE_PACKING"
131
+ exit 1
132
+ fi
133
+
134
+ # Log file setup
135
+ LOG_DIR="$BASE_DIR/logs"
136
+ mkdir -p "$LOG_DIR"
137
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
138
+ MAIN_LOG="$LOG_DIR/run_${TIMESTAMP}.log"
139
+
140
+ # Log function
141
+ log() {
142
+ message="[$(date '+%Y-%m-%d %H:%M:%S')] $1"
143
+ echo "$message" | tee -a "$MAIN_LOG"
144
+ }
145
+
146
+ # Main execution
147
+ log "Starting feature extraction and model training process"
148
+ log "Configuration:"
149
+ log " --base_dir=$BASE_DIR"
150
+ log " --dataset_prepared_path=$DATASET_PREPARED_PATH"
151
+ log " --model_path=$MODEL_PATH"
152
+ log " --preprocessing_workers=$PREPROCESSING_WORKERS"
153
+ log " --batch_size=$BATCH_SIZE"
154
+ log " --output_dir=$OUTPUT_DIR"
155
+ log " --model_name=$MODEL_NAME"
156
+ log " --max_tokens_per_batch=$MAX_TOKENS_PER_BATCH"
157
+ log " --torch_type=$TORCH_TYPE"
158
+ log " --disable_sample_packing=$DISABLE_SAMPLE_PACKING"
159
+
160
+ # Find valid cohorts and write to a temp file
161
+ TEMP_COHORT_LIST="$LOG_DIR/cohort_list_${TIMESTAMP}.txt"
162
+ > "$TEMP_COHORT_LIST" # Clear the file
163
+
164
+ # Find all valid cohorts (directories with train and test subdirectories)
165
+ for cohort_dir in "$BASE_DIR"/*; do
166
+ if [ -d "$cohort_dir" ] && [ -d "$cohort_dir/train" ] && [ -d "$cohort_dir/test" ]; then
167
+ cohort_name=$(basename "$cohort_dir")
168
+ echo "$cohort_name" >> "$TEMP_COHORT_LIST"
169
+ fi
170
+ done
171
+
172
+ # Check if any valid cohorts were found
173
+ if [ ! -s "$TEMP_COHORT_LIST" ]; then
174
+ log "ERROR: No valid cohorts found in $BASE_DIR"
175
+ rm -f "$TEMP_COHORT_LIST"
176
+ exit 1
177
+ fi
178
+
179
+ # Display all cohorts that will be processed
180
+ cohort_count=$(wc -l < "$TEMP_COHORT_LIST")
181
+ log "Found $cohort_count cohorts to process:"
182
+ while read -r cohort; do
183
+ log "- $cohort"
184
+ done < "$TEMP_COHORT_LIST"
185
+
186
+ # Process each cohort sequentially: feature extraction, then model training, then meds-evaluation
187
+ while read -r cohort_name; do
188
+ cohort_dir="$OUTPUT_DIR/$cohort_name"
189
+ output_dir="$cohort_dir/$MODEL_NAME"
190
+
191
+ log "===================================================="
192
+ log "Processing cohort: $cohort_name"
193
+ log "===================================================="
194
+
195
+ cohort_log="$LOG_DIR/${cohort_name}_${TIMESTAMP}.log"
196
+
197
+ # Create output directory if it doesn't exist
198
+ mkdir -p "$output_dir"
199
+
200
+ # Prepare command for feature extraction
201
+ FEATURE_CMD="python -u -m cehrgpt.tools.linear_prob.compute_cehrgpt_features \
202
+ --data_folder \"$BASE_DIR/$cohort_name/train/\" \
203
+ --test_data_folder \"$BASE_DIR/$cohort_name/test/\" \
204
+ --dataset_prepared_path \"$DATASET_PREPARED_PATH\" \
205
+ --model_name_or_path \"$MODEL_PATH\" \
206
+ --tokenizer_name_or_path \"$MODEL_PATH\" \
207
+ --output_dir \"$output_dir\" \
208
+ --preprocessing_num_workers \"$PREPROCESSING_WORKERS\" \
209
+ --per_device_eval_batch_size \"$BATCH_SIZE\" \
210
+ --max_tokens_per_batch \"$MAX_TOKENS_PER_BATCH\" \
211
+ --torch_type \"$TORCH_TYPE\""
212
+
213
+ # Add sample packing flag if not disabled
214
+ if [ "$DISABLE_SAMPLE_PACKING" = "false" ]; then
215
+ FEATURE_CMD="$FEATURE_CMD --sample_packing"
216
+ fi
217
+
218
+ # Step 1: Feature extraction
219
+ log "Starting feature extraction for $cohort_name..."
220
+ log "Command: $FEATURE_CMD"
221
+
222
+ eval "$FEATURE_CMD > \"$cohort_log\" 2>&1"
223
+
224
+ feature_extraction_status=$?
225
+ if [ $feature_extraction_status -ne 0 ]; then
226
+ log "ERROR: Feature extraction failed for $cohort_name. Check $cohort_log for details."
227
+ continue
228
+ fi
229
+
230
+ # Step 2: Model training
231
+ log "Starting model training for $cohort_name..."
232
+ log "Command: python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features --features_data_dir $output_dir --output_dir $output_dir"
233
+
234
+ python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features \
235
+ --features_data_dir "$output_dir" \
236
+ --output_dir "$output_dir" \
237
+ >> "$cohort_log" 2>&1
238
+ # Capture the training exit status right away: any later command would overwrite $?
239
+ model_training_status=$?
240
+ if [ $model_training_status -ne 0 ]; then
241
+ log "ERROR: Model training failed for $cohort_name. Check $cohort_log for details."
242
+ continue
243
+ fi
244
+
245
+ # Step 3: Evaluation (a failure here is reported but does not abort the cohort)
246
+ echo "Running meds-evaluation for logistic regression for $cohort_name..."
247
+ meds-evaluation-cli predictions_path="$output_dir/logistic/test_predictions" \
248
+ output_dir="$output_dir/logistic/"
249
+ meds_evaluation_status=$?
250
+ if [ $meds_evaluation_status -ne 0 ]; then
251
+ echo "Error: Running meds-evaluation failed for logistic regression for task $cohort_name"
252
+ fi
253
+
254
+ log "Successfully completed processing for $cohort_name"
255
+ done < "$TEMP_COHORT_LIST"
256
+
257
+ # Clean up
258
+ rm -f "$TEMP_COHORT_LIST"
259
+
260
+ log "All processing completed"
@@ -0,0 +1,36 @@
1
+ import os
2
+
3
+ import polars as pl
4
+
5
+ from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
6
+
7
+
8
+ def main(args):
9
+ dataset = pl.read_parquet(os.path.join(args.input_dir, "*.parquet"))
10
+ time_token_frequency_df = (
11
+ dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
12
+ .filter(pl.col("concept_id").map_elements(is_att_token))
13
+ .with_columns(
14
+ pl.col("concept_id")
15
+ .map_elements(extract_time_interval_in_days)
16
+ .alias("time_interval")
17
+ )
18
+ )
19
+ results = time_token_frequency_df.select(
20
+ pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
21
+ ).to_dicts()[0]
22
+ print(results)
23
+
24
+
25
+ if __name__ == "__main__":
26
+ import argparse
27
+
28
+ parser = argparse.ArgumentParser(description="EHR Irregularity analysis")
29
+ parser.add_argument(
30
+ "--input_dir",
31
+ dest="input_dir",
32
+ action="store",
33
+ help="The path for where the input data folder",
34
+ required=True,
35
+ )
36
+ main(parser.parse_args())
@@ -23,6 +23,7 @@ CEHRGPT_COLUMNS = [
23
23
  "num_of_visits",
24
24
  "values",
25
25
  "value_indicators",
26
+ "epoch_times",
26
27
  ]
27
28
 
28
29
  TRANSFORMER_COLUMNS = ["input_ids"]