cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.0
|
3
|
+
Version: 0.1.0
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -12,12 +12,14 @@ Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Requires-Python: >=3.10.0
|
13
13
|
Description-Content-Type: text/markdown
|
14
14
|
License-File: LICENSE
|
15
|
-
Requires-Dist: cehrbert==1.
|
15
|
+
Requires-Dist: cehrbert==1.4.1
|
16
|
+
Requires-Dist: cehrbert_data==0.0.7
|
16
17
|
Requires-Dist: openai==1.54.3
|
17
18
|
Requires-Dist: optuna==4.0.0
|
18
|
-
Requires-Dist: transformers==4.
|
19
|
-
Requires-Dist: tokenizers==0.19
|
20
|
-
Requires-Dist:
|
19
|
+
Requires-Dist: transformers==4.44.0
|
20
|
+
Requires-Dist: tokenizers==0.19.0
|
21
|
+
Requires-Dist: peft==0.10.0
|
22
|
+
Requires-Dist: lightgbm
|
21
23
|
Provides-Extra: dev
|
22
24
|
Requires-Dist: pre-commit; extra == "dev"
|
23
25
|
Requires-Dist: pytest; extra == "dev"
|
@@ -28,6 +30,7 @@ Requires-Dist: hypothesis; extra == "dev"
|
|
28
30
|
Requires-Dist: black; extra == "dev"
|
29
31
|
Provides-Extra: flash-attn
|
30
32
|
Requires-Dist: flash_attn; extra == "flash-attn"
|
33
|
+
Dynamic: license-file
|
31
34
|
|
32
35
|
# CEHRGPT
|
33
36
|
|
@@ -50,11 +53,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
|
|
50
53
|
To install CEHRGPT, clone this repository and install the required dependencies.
|
51
54
|
|
52
55
|
```bash
|
53
|
-
git clone https://github.com/knatarajan-lab/cehrgpt
|
54
|
-
cd cehrgpt
|
56
|
+
git clone https://github.com/knatarajan-lab/cehrgpt.git
|
57
|
+
cd cehrgpt
|
55
58
|
pip install .
|
56
59
|
```
|
57
60
|
|
61
|
+
## Pretrain
|
62
|
+
Pretrain cehrgpt using the Hugging Face trainer; the parameters can be found in the sample configuration YAML.
|
63
|
+
```bash
|
64
|
+
mkdir test_results
|
65
|
+
# This is NOT required when streaming is set to true
|
66
|
+
mkdir test_dataset_prepared
|
67
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
|
68
|
+
```
|
69
|
+
|
70
|
+
## Generate synthetic sequences
|
71
|
+
Generate synthetic sequences using the trained model
|
72
|
+
```bash
|
73
|
+
export TRANSFORMERS_VERBOSITY=info
|
74
|
+
export CUDA_VISIBLE_DEVICES="0"
|
75
|
+
python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
|
76
|
+
--model_folder test_results \
|
77
|
+
--tokenizer_folder test_results \
|
78
|
+
--output_folder test_results \
|
79
|
+
--num_of_patients 128 \
|
80
|
+
--batch_size 32 \
|
81
|
+
--buffer_size 128 \
|
82
|
+
--context_window 1024 \
|
83
|
+
--sampling_strategy TopPStrategy \
|
84
|
+
--top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
|
85
|
+
--epsilon_cutoff 0.00 \
|
86
|
+
--demographic_data_path sample_data/pretrain
|
87
|
+
```
|
88
|
+
|
89
|
+
## Convert synthetic sequences to OMOP
|
90
|
+
```bash
|
91
|
+
# omop converter requires the OHDSI vocabulary
|
92
|
+
export OMOP_VOCAB_DIR=""
|
93
|
+
# the omop derived tables need to be built using pyspark
|
94
|
+
export SPARK_WORKER_INSTANCES="1"
|
95
|
+
export SPARK_WORKER_CORES="8"
|
96
|
+
export SPARK_EXECUTOR_CORES="2"
|
97
|
+
export SPARK_DRIVER_MEMORY="2g"
|
98
|
+
export SPARK_EXECUTOR_MEMORY="2g"
|
99
|
+
|
100
|
+
# Convert the sequences, create the omop derived tables
|
101
|
+
sh scripts/omop_pipeline.sh \
|
102
|
+
test_results/top_p10000/generated_sequences/ \
|
103
|
+
test_results/top_p10000/restored_omop/ \
|
104
|
+
$OMOP_VOCAB_DIR
|
105
|
+
```
|
106
|
+
|
58
107
|
## Citation
|
59
108
|
```
|
60
109
|
@article{cehrgpt2024,
|
@@ -63,4 +112,3 @@ pip install .
|
|
63
112
|
journal={arXiv preprint arXiv:2402.04400},
|
64
113
|
year={2024}
|
65
114
|
}
|
66
|
-
```
|
@@ -11,21 +11,22 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
|
|
11
11
|
cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
|
12
12
|
cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
|
13
13
|
cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=
|
15
|
-
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=
|
16
|
-
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=
|
14
|
+
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=t9vpN05e--CiKgIlxLP0aLacISnvWWDPXtuFuJi3ksE,3736
|
15
|
+
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=DOvIF4Wzkd8-IO3zpIRZkX1j0IdvefaiSnrDn1YivCk,27912
|
16
|
+
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=eI8CTk6yJ4DlNJWrNAkEmhWh353NeLqg5rwPpKqKT-U,17308
|
17
|
+
cehrgpt/data/sample_packing_sampler.py,sha256=0uKTbvtXpfS81esy_3epJ88eohyJPK46bfmxhle1fws,5419
|
17
18
|
cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
19
|
cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
|
19
|
-
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256
|
20
|
-
cehrgpt/generation/omop_converter_batch.py,sha256
|
20
|
+
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
|
21
|
+
cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
|
21
22
|
cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
|
22
23
|
cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
-
cehrgpt/models/config.py,sha256=
|
24
|
-
cehrgpt/models/hf_cehrgpt.py,sha256=
|
24
|
+
cehrgpt/models/config.py,sha256=Y3CiXZWniLP9_RlpU80Oe9gjn5leLmTYnNe_fWqfJLQ,10158
|
25
|
+
cehrgpt/models/hf_cehrgpt.py,sha256=3EQIOfa--oz4f8bM8KzbDi98G3XrUEQkox1vmBN001M,83321
|
25
26
|
cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
|
26
27
|
cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
|
27
28
|
cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
|
28
|
-
cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=
|
29
|
+
cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jjCRqS29IzMnKp40jNOs80UKh2z9lK5S6M02GSB-4mk,42351
|
29
30
|
cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
30
31
|
cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
|
31
32
|
cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
|
@@ -36,11 +37,18 @@ cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
|
|
36
37
|
cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
|
37
38
|
cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
|
38
39
|
cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
39
|
-
cehrgpt/runners/
|
40
|
-
cehrgpt/runners/
|
41
|
-
cehrgpt/runners/
|
42
|
-
cehrgpt/runners/
|
43
|
-
cehrgpt/runners/
|
40
|
+
cehrgpt/runners/data_utils.py,sha256=ScZZnfXwgXKaMvKgFzdb4vtQ7F_lw97O5uNsFbfsyP4,10620
|
41
|
+
cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
|
42
|
+
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=bkPl30Y9CSXBlmMkH-3cA3-aW8XJK36Q-adx___WjkE,26921
|
43
|
+
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ViVa_flEGdk_SO0psMR7ho-o79igsz_l1x80u81WJ3A,23875
|
44
|
+
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=VrqgDSiAMfGyHEIodoOg_8LU5O0ndWf9EE0YOKDFKKA,7019
|
45
|
+
cehrgpt/runners/hyperparameter_search_util.py,sha256=pWFmGo9Ezju4YmuZ-ohbAbYB0GGMfIDVUCyvcTxS1iU,9153
|
46
|
+
cehrgpt/runners/sample_packing_trainer.py,sha256=aezX30vxpP1DDcH5hO-yn395NqBKi2Xhb0mFNHi9OBs,7340
|
47
|
+
cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
|
+
cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
|
49
|
+
cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
|
50
|
+
cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
|
51
|
+
cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
|
44
52
|
cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
45
53
|
cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
|
46
54
|
cehrgpt/time_to_event/time_to_event_prediction.py,sha256=Ajesq2gSsILghWHCTLiiBhcyOCa7m6JPPMdi_xvBlR4,12624
|
@@ -50,11 +58,15 @@ cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKI
|
|
50
58
|
cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
|
51
59
|
cehrgpt/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
52
60
|
cehrgpt/tools/ehrshot_benchmark.py,sha256=E-m_5srlYEw7Y7i9twIJWDvrkwNlop-6yZB-80FZid0,2667
|
61
|
+
cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIiHH9dwMtMi3PVl8jXm2O06_gI,4177
|
53
62
|
cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
|
54
63
|
cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
|
55
64
|
cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
|
56
|
-
cehrgpt
|
57
|
-
cehrgpt
|
58
|
-
cehrgpt
|
59
|
-
cehrgpt-0.0.
|
60
|
-
cehrgpt-0.0.
|
65
|
+
cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
66
|
+
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=jVgAmBrZKp7ABfqKkzwV5Vl_G9jDCjPl98NSVmSwHpE,19291
|
67
|
+
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
|
68
|
+
cehrgpt-0.1.0.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
|
69
|
+
cehrgpt-0.1.0.dist-info/METADATA,sha256=V02vsptjJRD_bybXVRFXPrJa-By9CX4j-oAA3EfXFq4,4933
|
70
|
+
cehrgpt-0.1.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
71
|
+
cehrgpt-0.1.0.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
|
72
|
+
cehrgpt-0.1.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|