cehrgpt 0.1.1__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/PKG-INFO +95 -1
  2. cehrgpt-0.1.2/README.md +174 -0
  3. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +57 -33
  4. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +22 -9
  5. cehrgpt-0.1.2/src/cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  6. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  7. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/hf_cehrgpt.py +17 -6
  8. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/data_utils.py +17 -6
  9. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +9 -1
  10. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +12 -0
  11. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +20 -30
  12. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/PKG-INFO +95 -1
  13. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/SOURCES.txt +1 -0
  14. cehrgpt-0.1.1/README.md +0 -80
  15. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.github/workflows/build-python.yaml +0 -0
  16. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.github/workflows/tests.yaml +0 -0
  17. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.gitignore +0 -0
  18. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/.pre-commit-config.yaml +0 -0
  19. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/LICENSE +0 -0
  20. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/pyproject.toml +0 -0
  21. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
  22. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/omop_vocab/concept/concept.parquet +0 -0
  23. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/pretrain/patient_sequence.parquet +0 -0
  24. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
  25. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
  26. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/level_three_evaluation.sh +0 -0
  27. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/omop_pipeline.sh +0 -0
  28. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/pool_generated_sequences.sh +0 -0
  29. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/scripts/run_linear_prob.sh +0 -0
  30. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/setup.cfg +0 -0
  31. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/__init__.py +0 -0
  32. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/__init__.py +0 -0
  33. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/__init__.py +0 -0
  34. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/irregularity.py +0 -0
  35. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
  36. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
  37. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
  38. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
  39. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
  40. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
  41. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/analysis/privacy/utils.py +0 -0
  42. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/cehrgpt_args.py +0 -0
  43. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/__init__.py +0 -0
  44. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/hf_cehrgpt_dataset.py +0 -0
  45. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/data/sample_packing_sampler.py +0 -0
  46. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/__init__.py +0 -0
  47. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
  48. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/omop_converter_batch.py +0 -0
  49. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/generation/omop_entity.py +0 -0
  50. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/gpt_utils.py +0 -0
  51. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/__init__.py +0 -0
  52. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/config.py +0 -0
  53. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/hf_modeling_outputs.py +0 -0
  54. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
  55. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/special_tokens.py +0 -0
  56. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +0 -0
  57. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/__init__.py +0 -0
  58. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/condition_era.py +0 -0
  59. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/observation_period.py +0 -0
  60. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/omop_argparse.py +0 -0
  61. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/omop_table_builder.py +0 -0
  62. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/__init__.py +0 -0
  63. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/condition_era.py +0 -0
  64. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/queries/observation_period.py +0 -0
  65. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
  66. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/__init__.py +0 -0
  67. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/gpt_runner_util.py +0 -0
  68. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +0 -0
  69. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/hyperparameter_search_util.py +0 -0
  70. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/runners/sample_packing_trainer.py +0 -0
  71. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/__init__.py +0 -0
  72. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/generate_plots.py +0 -0
  73. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/run_simulation.sh +0 -0
  74. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/time_embedding_simulation.py +0 -0
  75. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/simulations/time_token_simulation.py +0 -0
  76. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/__init__.py +0 -0
  77. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/1_year_cabg.yaml +0 -0
  78. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
  79. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
  80. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
  81. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_model.py +0 -0
  82. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_prediction.py +0 -0
  83. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
  84. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/__init__.py +0 -0
  85. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
  86. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
  87. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
  88. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/__init__.py +0 -0
  89. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +0 -0
  90. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
  91. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
  92. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
  93. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/requires.txt +0 -0
  94. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/src/cehrgpt.egg-info/top_level.txt +0 -0
  95. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/__init__.py +0 -0
  96. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/__init__.py +0 -0
  97. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/__init__.py +0 -0
  98. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +0 -0
  99. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +0 -0
  100. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +0 -0
  101. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/__init__.py +0 -0
  102. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/gpt_utils_test.py +0 -0
  103. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/__init__.py +0 -0
  104. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/model_utils_test.py +0 -0
  105. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/__init__.py +0 -0
  106. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
  107. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
  108. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
  109. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/runners/__init__.py +0 -0
  110. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +0 -0
  111. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/tokenization_test.py +0 -0
  112. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/tools/__init__.py +0 -0
  113. {cehrgpt-0.1.1 → cehrgpt-0.1.2}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cehrgpt
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -105,6 +105,100 @@ sh scripts/omop_pipeline.sh \
105
105
  $OMOP_VOCAB_DIR
106
106
  ```
107
107
 
108
+ # MEDS Support
109
+
110
+ This section demonstrates how to pretrain CEHR-GPT on MIMIC-IV data in the MEDS (Medical Event Data Standard) format and then use the trained model to generate synthetic patient trajectories.
111
+
112
+ ## Prerequisites
113
+
114
+ Set up the required environment variables before beginning:
115
+
116
+ ```bash
117
+ export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
118
+ export MEDS_DIR="" # Path to MEDS data directory
119
+ export MEDS_READER_DIR="" # Path to MEDS reader output directory
120
+ ```
121
+
122
+ ## Step 1: Create MIMIC MEDS Data
123
+
124
+ Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
125
+
126
+ ## Step 2: Create the MEDS Reader
127
+
128
+ Convert the MEDS data for use with CEHR-GPT:
129
+
130
+ ```bash
131
+ meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
132
+ ```
133
+
134
+ ## Step 3: Pretrain CEHR-GPT
135
+
136
+ Run the pretraining process using the prepared MEDS data:
137
+
138
+ ```bash
139
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
140
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
141
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
142
+ --output_dir $CEHR_GPT_MODEL_DIR \
143
+ --data_folder $MEDS_READER_DIR \
144
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
145
+ --do_train true --seed 42 \
146
+ --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
147
+ --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
148
+ --evaluation_strategy epoch --save_strategy epoch \
149
+ --sample_packing --max_tokens_per_batch 16384 \
150
+ --warmup_steps 500 --weight_decay 0.01 \
151
+ --num_train_epochs 50 --learning_rate 0.0002 \
152
+ --use_early_stopping --early_stopping_threshold 0.001 \
153
+ --is_data_in_meds --inpatient_att_function_type day \
154
+ --att_function_type day --include_inpatient_hour_token \
155
+ --include_auxiliary_token --include_demographic_prompt \
156
+ --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
157
+ ```
158
+
159
+ ## Step 4: Generate MEDS Trajectories
160
+
161
+ ### Environment Setup for Trajectory Generation
162
+
163
+ Configure additional environment variables for trajectory generation. The cohort label files must provide the task-label columns `subject_id` and `prediction_time`, plus an optional `boolean_value` column:
164
+
165
+ ```bash
166
+ # MEDS_LABEL_COHORT_DIR must contain a set of parquet files
167
+ export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
168
+ export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
169
+ ```
170
+
171
+ ### Generate Trajectories
172
+
173
+ Create synthetic patient trajectories using the trained model:
174
+
175
+ > **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
176
+
177
+ ```bash
178
+ python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
179
+ --cohort_folder $MEDS_LABEL_COHORT_DIR \
180
+ --data_folder $MEDS_READER_DIR \
181
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
182
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
183
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
184
+ --output_dir $MEDS_TRAJECTORY_DIR \
185
+ --per_device_eval_batch_size 16 \
186
+ --num_of_trajectories_per_sample 2 \
187
+ --generation_input_length 4096 \
188
+ --generation_max_new_tokens 4096 \
189
+ --is_data_in_meds \
190
+ --att_function_type day --inpatient_att_function_type day \
191
+ --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
192
+ --include_auxiliary_token --include_demographic_prompt \
193
+ --include_inpatient_hour_token
194
+ ```
195
+
196
+ ### Parameters Explanation
197
+
198
+ - `generation_input_length`: Controls the length of input context for generation
199
+ - `generation_max_new_tokens`: Maximum number of new tokens to generate
200
+ - `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
201
+
108
202
  ## Citation
109
203
  ```
110
204
  @article{cehrgpt2024,
@@ -0,0 +1,174 @@
1
+ # CEHRGPT
2
+
3
+ [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
4
+ ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
5
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
6
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
7
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
8
+
9
+ ## Description
10
+ CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
11
+
12
+ ## Features
13
+ - **Synthetic Patient Data Generation**: Generates comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes.
14
+ - **Privacy-Preserving**: Implements techniques to ensure the generated data does not reveal identifiable information.
15
+ - **Compatibility with OMOP**: Fully compatible with the OMOP common data model, allowing seamless integration with existing healthcare data systems.
16
+ - **Extensible**: Designed to be adaptable to new datasets and different EHR systems.
17
+
18
+ ## Installation
19
+ To install CEHRGPT, clone this repository and install the required dependencies.
20
+
21
+ ```bash
22
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
23
+ cd cehrgpt
24
+ pip install .
25
+ ```
26
+
27
+ ## Pretrain
28
+ Pretrain CEHRGPT using the Hugging Face trainer; the training parameters are defined in the sample configuration YAML (`sample_configs/cehrgpt_pretrain_sample_config.yaml`).
29
+ ```bash
30
+ mkdir test_results
31
+ # This is NOT required when streaming is set to true
32
+ mkdir test_dataset_prepared
33
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
34
+ ```
35
+
36
+ ## Generate synthetic sequences
37
+ Generate synthetic sequences using the trained model
38
+ ```bash
39
+ export TRANSFORMERS_VERBOSITY=info
40
+ export CUDA_VISIBLE_DEVICES="0"
41
+ python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
42
+ --model_folder test_results \
43
+ --tokenizer_folder test_results \
44
+ --output_folder test_results \
45
+ --num_of_patients 128 \
46
+ --batch_size 32 \
47
+ --buffer_size 128 \
48
+ --context_window 1024 \
49
+ --sampling_strategy TopPStrategy \
50
+ --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
51
+ --epsilon_cutoff 0.00 \
52
+ --demographic_data_path sample_data/pretrain
53
+ ```
54
+
55
+ ## Convert synthetic sequences to OMOP
56
+ ```bash
57
+ # omop converter requires the OHDSI vocabulary
58
+ export OMOP_VOCAB_DIR=""
59
+ # the omop derived tables need to be built using pyspark
60
+ export SPARK_WORKER_INSTANCES="1"
61
+ export SPARK_WORKER_CORES="8"
62
+ export SPARK_EXECUTOR_CORES="2"
63
+ export SPARK_DRIVER_MEMORY="2g"
64
+ export SPARK_EXECUTOR_MEMORY="2g"
65
+
66
+ # Convert the sequences, create the omop derived tables
67
+ sh scripts/omop_pipeline.sh \
68
+ test_results/top_p10000/generated_sequences/ \
69
+ test_results/top_p10000/restored_omop/ \
70
+ $OMOP_VOCAB_DIR
71
+ ```
72
+
73
+ # MEDS Support
74
+
75
+ This section demonstrates how to pretrain CEHR-GPT on MIMIC-IV data in the MEDS (Medical Event Data Standard) format and then use the trained model to generate synthetic patient trajectories.
76
+
77
+ ## Prerequisites
78
+
79
+ Set up the required environment variables before beginning:
80
+
81
+ ```bash
82
+ export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
83
+ export MEDS_DIR="" # Path to MEDS data directory
84
+ export MEDS_READER_DIR="" # Path to MEDS reader output directory
85
+ ```
86
+
87
+ ## Step 1: Create MIMIC MEDS Data
88
+
89
+ Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
90
+
91
+ ## Step 2: Create the MEDS Reader
92
+
93
+ Convert the MEDS data for use with CEHR-GPT:
94
+
95
+ ```bash
96
+ meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
97
+ ```
98
+
99
+ ## Step 3: Pretrain CEHR-GPT
100
+
101
+ Run the pretraining process using the prepared MEDS data:
102
+
103
+ ```bash
104
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
105
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
106
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
107
+ --output_dir $CEHR_GPT_MODEL_DIR \
108
+ --data_folder $MEDS_READER_DIR \
109
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
110
+ --do_train true --seed 42 \
111
+ --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
112
+ --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
113
+ --evaluation_strategy epoch --save_strategy epoch \
114
+ --sample_packing --max_tokens_per_batch 16384 \
115
+ --warmup_steps 500 --weight_decay 0.01 \
116
+ --num_train_epochs 50 --learning_rate 0.0002 \
117
+ --use_early_stopping --early_stopping_threshold 0.001 \
118
+ --is_data_in_meds --inpatient_att_function_type day \
119
+ --att_function_type day --include_inpatient_hour_token \
120
+ --include_auxiliary_token --include_demographic_prompt \
121
+ --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
122
+ ```
123
+
124
+ ## Step 4: Generate MEDS Trajectories
125
+
126
+ ### Environment Setup for Trajectory Generation
127
+
128
+ Configure additional environment variables for trajectory generation. The cohort label files must provide the task-label columns `subject_id` and `prediction_time`, plus an optional `boolean_value` column:
129
+
130
+ ```bash
131
+ # MEDS_LABEL_COHORT_DIR must contain a set of parquet files
132
+ export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
133
+ export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
134
+ ```
135
+
136
+ ### Generate Trajectories
137
+
138
+ Create synthetic patient trajectories using the trained model:
139
+
140
+ > **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
141
+
142
+ ```bash
143
+ python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
144
+ --cohort_folder $MEDS_LABEL_COHORT_DIR \
145
+ --data_folder $MEDS_READER_DIR \
146
+ --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
147
+ --model_name_or_path $CEHR_GPT_MODEL_DIR \
148
+ --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
149
+ --output_dir $MEDS_TRAJECTORY_DIR \
150
+ --per_device_eval_batch_size 16 \
151
+ --num_of_trajectories_per_sample 2 \
152
+ --generation_input_length 4096 \
153
+ --generation_max_new_tokens 4096 \
154
+ --is_data_in_meds \
155
+ --att_function_type day --inpatient_att_function_type day \
156
+ --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
157
+ --include_auxiliary_token --include_demographic_prompt \
158
+ --include_inpatient_hour_token
159
+ ```
160
+
161
+ ### Parameters Explanation
162
+
163
+ - `generation_input_length`: Controls the length of input context for generation
164
+ - `generation_max_new_tokens`: Maximum number of new tokens to generate
165
+ - `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
166
+
167
+ ## Citation
168
+ ```
169
+ @article{cehrgpt2024,
170
+ title={CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines},
171
+ author={Natarajan, K and others},
172
+ journal={arXiv preprint arXiv:2402.04400},
173
+ year={2024}
174
+ }
@@ -162,6 +162,22 @@ class CehrGptDataCollator:
162
162
  f"batch['input_ids']: {batch['input_ids']} "
163
163
  )
164
164
 
165
+ if "epoch_times" in examples[0]:
166
+ batch_epoch_times = [
167
+ self._try_reverse_tensor(
168
+ self._convert_to_tensor(example["epoch_times"])
169
+ )
170
+ for example in examples
171
+ ]
172
+ # Pad sequences to the max length in the batch
173
+ batch["epoch_times"] = self._try_reverse_tensor(
174
+ pad_sequence(
175
+ batch_epoch_times,
176
+ batch_first=True,
177
+ padding_value=0,
178
+ ).to(torch.float32)
179
+ )
180
+
165
181
  if "position_ids" in examples[0]:
166
182
  batch_position_ids = [
167
183
  self._try_reverse_tensor(
@@ -663,7 +679,9 @@ class CehrGptDataCollator:
663
679
 
664
680
  # Subtract one for the [END] token when sample_packing is not enabled
665
681
  new_max_length = (
666
- max_length_allowed if sample_packing else max_length_allowed - 1
682
+ max_length_allowed - 1
683
+ if not sample_packing and self.pretraining
684
+ else max_length_allowed
667
685
  )
668
686
 
669
687
  if self.include_ttv_prediction:
@@ -685,13 +703,20 @@ class CehrGptDataCollator:
685
703
 
686
704
  # Return the record directly if the actual sequence length is less than the max sequence
687
705
  if seq_length <= new_max_length:
688
- if not sample_packing:
706
+ if not sample_packing and self.pretraining:
689
707
  record["input_ids"] = torch.concat(
690
708
  [
691
709
  self._convert_to_tensor(record["input_ids"]),
692
710
  self._convert_to_tensor([eos_token]),
693
711
  ]
694
712
  )
713
+ if "epoch_times" in record:
714
+ record["epoch_times"] = torch.concat(
715
+ [
716
+ self._convert_to_tensor(record["epoch_times"]),
717
+ self._convert_to_tensor([record["epoch_times"][-1]]),
718
+ ]
719
+ )
695
720
  if self.include_values:
696
721
  record["value_indicators"] = torch.concat(
697
722
  [
@@ -727,6 +752,10 @@ class CehrGptDataCollator:
727
752
  record["input_ids"] = self._convert_to_tensor(
728
753
  record["input_ids"][start_index : end_index + 1]
729
754
  )
755
+ if "epoch_times" in record:
756
+ record["epoch_times"] = self._convert_to_tensor(
757
+ record["epoch_times"][start_index : end_index + 1]
758
+ )
730
759
  if self.include_values:
731
760
  record["value_indicators"] = self._convert_to_tensor(
732
761
  record["value_indicators"][start_index : end_index + 1]
@@ -760,6 +789,11 @@ class CehrGptDataCollator:
760
789
  if sample_packing and "position_ids" in record:
761
790
  record["position_ids"] = record["position_ids"][0:end_index]
762
791
 
792
+ if "epoch_times" in record:
793
+ record["epoch_times"] = self._convert_to_tensor(
794
+ record["epoch_times"][0:end_index]
795
+ )
796
+
763
797
  if self.include_values:
764
798
  record["value_indicators"] = self._convert_to_tensor(
765
799
  record["value_indicators"][0:end_index]
@@ -792,6 +826,17 @@ class CehrGptDataCollator:
792
826
  ),
793
827
  ]
794
828
  )
829
+ if "epoch_times" in record:
830
+ record["epoch_times"] = torch.concat(
831
+ [
832
+ torch.zeros(
833
+ [record["epoch_times"][0]], dtype=torch.float32
834
+ ),
835
+ self._convert_to_tensor(
836
+ record["epoch_times"][token_index:seq_length]
837
+ ),
838
+ ]
839
+ )
795
840
  if self.include_values:
796
841
  record["value_indicators"] = torch.concat(
797
842
  [
@@ -830,7 +875,7 @@ class CehrGptDataCollator:
830
875
  )
831
876
  break
832
877
  else:
833
- start_index = seq_length - new_max_length
878
+ start_index = max(seq_length - new_max_length, 0)
834
879
  end_index = seq_length
835
880
  for i in range(start_index, end_index):
836
881
  current_token = record["input_ids"][i]
@@ -842,6 +887,11 @@ class CehrGptDataCollator:
842
887
  ]
843
888
  if sample_packing and "position_ids" in record:
844
889
  record["position_ids"] = record["position_ids"][i:end_index]
890
+
891
+ if "epoch_times" in record:
892
+ record["epoch_times"] = self._convert_to_tensor(
893
+ record["epoch_times"][i:end_index]
894
+ )
845
895
  if self.include_values:
846
896
  record["value_indicators"] = record["value_indicators"][
847
897
  i:end_index
@@ -863,6 +913,10 @@ class CehrGptDataCollator:
863
913
  ]
864
914
  if sample_packing and "position_ids" in record:
865
915
  record["position_ids"] = record["position_ids"][-new_max_length:]
916
+ if "epoch_times" in record:
917
+ record["epoch_times"] = self._convert_to_tensor(
918
+ record["epoch_times"][-new_max_length:]
919
+ )
866
920
  if self.include_values:
867
921
  record["value_indicators"] = record["value_indicators"][
868
922
  -new_max_length:
@@ -873,36 +927,6 @@ class CehrGptDataCollator:
873
927
  -new_max_length:
874
928
  ]
875
929
 
876
- if not sample_packing:
877
- # Finally we add the end token to the end of the sequence
878
- record["input_ids"] = torch.concat(
879
- [
880
- self._convert_to_tensor(record["input_ids"]),
881
- self._convert_to_tensor([eos_token]),
882
- ]
883
- )
884
- if self.include_values:
885
- record["value_indicators"] = torch.concat(
886
- [
887
- self._convert_to_tensor(record["value_indicators"]),
888
- self._convert_to_tensor([False]),
889
- ]
890
- ).to(torch.bool)
891
- record["values"] = torch.concat(
892
- [
893
- self._convert_to_tensor(record["values"]),
894
- self._convert_to_tensor(
895
- [self.tokenizer.pad_value_token_id]
896
- ),
897
- ]
898
- )
899
- if self.include_ttv_prediction:
900
- record["time_to_visits"] = torch.concat(
901
- [
902
- record["time_to_visits"],
903
- self._convert_to_tensor([-100.0]),
904
- ]
905
- )
906
930
  return record
907
931
 
908
932
 
@@ -21,7 +21,6 @@ from cehrbert_data.const.artificial_tokens import (
21
21
  DISCHARGE_UNKNOWN_TOKEN,
22
22
  GENDER_UNKNOWN_TOKEN,
23
23
  RACE_UNKNOWN_TOKEN,
24
- VISIT_UNKNOWN_TOKEN,
25
24
  )
26
25
  from cehrbert_data.const.common import NA
27
26
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
@@ -47,10 +46,16 @@ CEHRGPT_COLUMNS = [
47
46
  ]
48
47
 
49
48
 
50
- def convert_date_to_posix_time(index_date: datetime.date) -> float:
51
- return datetime.datetime.combine(
52
- index_date, datetime.datetime.min.time()
53
- ).timestamp()
49
def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
    """Convert a date, a datetime, or an existing POSIX timestamp to POSIX seconds.

    Naive ``datetime.date`` / ``datetime.datetime`` values are interpreted as UTC,
    so the result does not depend on the host machine's local timezone.
    Timezone-aware datetimes keep their actual instant.

    Args:
        index_date: A ``datetime.datetime`` (time of day is preserved), a
            ``datetime.date`` (midnight UTC is assumed), or a number that is
            already a POSIX timestamp.

    Returns:
        Seconds since the Unix epoch. Numeric inputs are returned unchanged.
    """
    # ``datetime.datetime`` is a subclass of ``datetime.date``, so it must be
    # tested first; otherwise the date branch would swallow it and drop the
    # time-of-day component.
    if isinstance(index_date, datetime.datetime):
        if index_date.tzinfo is None:
            # Naive timestamps are treated as UTC rather than local time.
            index_date = index_date.replace(tzinfo=datetime.timezone.utc)
        # Aware datetimes already identify a unique instant; replacing their
        # tzinfo would shift that instant, so they are used as-is.
        return index_date.timestamp()
    if isinstance(index_date, datetime.date):
        return (
            datetime.datetime.combine(index_date, datetime.datetime.min.time())
            .replace(tzinfo=datetime.timezone.utc)
            .timestamp()
        )
    # Already numeric (presumably a POSIX timestamp); pass through untouched.
    return index_date
54
59
 
55
60
 
56
61
  class DatasetMappingDecorator(DatasetMapping):
@@ -128,7 +133,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
128
133
  cehrgpt_record["concept_as_values"].append(concept_as_value)
129
134
  cehrgpt_record["units"].append(unit)
130
135
  cehrgpt_record["is_numeric_types"].append(is_numeric_type)
131
- cehrgpt_record["epoch_times"].append(time.timestamp())
136
+ cehrgpt_record["epoch_times"].append(
137
+ time.replace(tzinfo=datetime.timezone.utc).timestamp()
138
+ )
132
139
 
133
140
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
134
141
  cehrgpt_record = {
@@ -360,7 +367,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
360
367
  cehrgpt_record["num_of_visits"] = len(visits)
361
368
 
362
369
  if record.get("index_date", None) is not None:
363
- cehrgpt_record["index_date"] = record["index_date"]
370
+ cehrgpt_record["index_date"] = (
371
+ record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
372
+ )
364
373
  if record.get("label", None) is not None:
365
374
  cehrgpt_record["label"] = record["label"]
366
375
  if record.get("age_at_index", None) is not None:
@@ -529,9 +538,13 @@ class ExtractTokenizedSequenceDataMapping:
529
538
  prediction_start_end_times = [
530
539
  (
531
540
  self._calculate_prediction_start_time(
532
- prediction_time_label_map["index_date"].timestamp()
541
+ prediction_time_label_map["index_date"]
542
+ .replace(tzinfo=datetime.timezone.utc)
543
+ .timestamp()
533
544
  ),
534
- prediction_time_label_map["index_date"].timestamp(),
545
+ prediction_time_label_map["index_date"]
546
+ .replace(tzinfo=datetime.timezone.utc)
547
+ .timestamp(),
535
548
  prediction_time_label_map["label"],
536
549
  )
537
550
  for prediction_time_label_map in prediction_times