cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/generation/omop_converter_batch.py +3 -0
  7. cehrgpt/models/config.py +10 -0
  8. cehrgpt/models/hf_cehrgpt.py +244 -73
  9. cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
  10. cehrgpt/runners/data_utils.py +243 -0
  11. cehrgpt/runners/gpt_runner_util.py +0 -10
  12. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
  13. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
  14. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
  15. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  16. cehrgpt/runners/sample_packing_trainer.py +168 -0
  17. cehrgpt/simulations/__init__.py +0 -0
  18. cehrgpt/simulations/generate_plots.py +95 -0
  19. cehrgpt/simulations/run_simulation.sh +24 -0
  20. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  21. cehrgpt/simulations/time_token_simulation.py +177 -0
  22. cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
  23. cehrgpt/tools/linear_prob/__init__.py +0 -0
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  25. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  26. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
  27. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
  28. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  29. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  30. {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: cehrgpt
3
- Version: 0.0.1
3
+ Version: 0.1.0
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -12,12 +12,14 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10.0
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: cehrbert==1.3.3
15
+ Requires-Dist: cehrbert==1.4.1
16
+ Requires-Dist: cehrbert_data==0.0.7
16
17
  Requires-Dist: openai==1.54.3
17
18
  Requires-Dist: optuna==4.0.0
18
- Requires-Dist: transformers==4.40.0
19
- Requires-Dist: tokenizers==0.19
20
- Requires-Dist: trl==0.11.4
19
+ Requires-Dist: transformers==4.44.0
20
+ Requires-Dist: tokenizers==0.19.0
21
+ Requires-Dist: peft==0.10.0
22
+ Requires-Dist: lightgbm
21
23
  Provides-Extra: dev
22
24
  Requires-Dist: pre-commit; extra == "dev"
23
25
  Requires-Dist: pytest; extra == "dev"
@@ -28,6 +30,7 @@ Requires-Dist: hypothesis; extra == "dev"
28
30
  Requires-Dist: black; extra == "dev"
29
31
  Provides-Extra: flash-attn
30
32
  Requires-Dist: flash_attn; extra == "flash-attn"
33
+ Dynamic: license-file
31
34
 
32
35
  # CEHRGPT
33
36
 
@@ -50,11 +53,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
50
53
  To install CEHRGPT, clone this repository and install the required dependencies.
51
54
 
52
55
  ```bash
53
- git clone https://github.com/knatarajan-lab/cehrgpt-public.git
54
- cd cehrgpt-public
56
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
57
+ cd cehrgpt
55
58
  pip install .
56
59
  ```
57
60
 
61
+ ## Pretrain
62
+ Pretrain cehrgpt using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file.
63
+ ```bash
64
+ mkdir test_results
65
+ # This is NOT required when streaming is set to true
66
+ mkdir test_dataset_prepared
67
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
68
+ ```
69
+
70
+ ## Generate synthetic sequences
71
+ Generate synthetic sequences using the trained model
72
+ ```bash
73
+ export TRANSFORMERS_VERBOSITY=info
74
+ export CUDA_VISIBLE_DEVICES="0"
75
+ python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
76
+ --model_folder test_results \
77
+ --tokenizer_folder test_results \
78
+ --output_folder test_results \
79
+ --num_of_patients 128 \
80
+ --batch_size 32 \
81
+ --buffer_size 128 \
82
+ --context_window 1024 \
83
+ --sampling_strategy TopPStrategy \
84
+ --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
85
+ --epsilon_cutoff 0.00 \
86
+ --demographic_data_path sample_data/pretrain
87
+ ```
88
+
89
+ ## Convert synthetic sequences to OMOP
90
+ ```bash
91
+ # omop converter requires the OHDSI vocabulary
92
+ export OMOP_VOCAB_DIR=""
93
+ # the omop derived tables need to be built using pyspark
94
+ export SPARK_WORKER_INSTANCES="1"
95
+ export SPARK_WORKER_CORES="8"
96
+ export SPARK_EXECUTOR_CORES="2"
97
+ export SPARK_DRIVER_MEMORY="2g"
98
+ export SPARK_EXECUTOR_MEMORY="2g"
99
+
100
+ # Convert the sequences, create the omop derived tables
101
+ sh scripts/omop_pipeline.sh \
102
+ test_results/top_p10000/generated_sequences/ \
103
+ test_results/top_p10000/restored_omop/ \
104
+ $OMOP_VOCAB_DIR
105
+ ```
106
+
58
107
  ## Citation
59
108
  ```
60
109
  @article{cehrgpt2024,
@@ -63,4 +112,3 @@ pip install .
63
112
  journal={arXiv preprint arXiv:2402.04400},
64
113
  year={2024}
65
114
  }
66
- ```
@@ -11,21 +11,22 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
11
11
  cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
12
12
  cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
13
13
  cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- cehrgpt/data/hf_cehrgpt_dataset.py,sha256=7hvjjqE8WInVuRvAtNkFI_J-xluFBv1Ij4TPTdUxPM4,2570
15
- cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=RYw5Isrwa4sdyQQ3Nf3cu7xPDA3m-c5ecCFf_y1TJKY,20497
16
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=aQ0gsThOFhrh9ExpJhRmuiwN9ShIKheLgCIci-N7HOM,4305
14
+ cehrgpt/data/hf_cehrgpt_dataset.py,sha256=t9vpN05e--CiKgIlxLP0aLacISnvWWDPXtuFuJi3ksE,3736
15
+ cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=DOvIF4Wzkd8-IO3zpIRZkX1j0IdvefaiSnrDn1YivCk,27912
16
+ cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=eI8CTk6yJ4DlNJWrNAkEmhWh353NeLqg5rwPpKqKT-U,17308
17
+ cehrgpt/data/sample_packing_sampler.py,sha256=0uKTbvtXpfS81esy_3epJ88eohyJPK46bfmxhle1fws,5419
17
18
  cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
19
  cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
19
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=-WLpKlulVVDJSdA2jXyp87gfLW4Q3aAtwULK8fDtn_E,11408
20
- cehrgpt/generation/omop_converter_batch.py,sha256=SDpWjqzi8dsgVzbbFes42GMdZEvrJ3sm4RbP5UpmIlk,25280
20
+ cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
21
+ cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
21
22
  cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
22
23
  cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- cehrgpt/models/config.py,sha256=xek4W_siO7WtMAKE7zDsENotsIE70F8dcW-PTC0kBKk,9700
24
- cehrgpt/models/hf_cehrgpt.py,sha256=YrHhT8c92xcOVTb6FjFQokyHrDOcXgEDMBs0BksSBpA,75739
24
+ cehrgpt/models/config.py,sha256=Y3CiXZWniLP9_RlpU80Oe9gjn5leLmTYnNe_fWqfJLQ,10158
25
+ cehrgpt/models/hf_cehrgpt.py,sha256=3EQIOfa--oz4f8bM8KzbDi98G3XrUEQkox1vmBN001M,83321
25
26
  cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
26
27
  cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
27
28
  cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
28
- cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jQR5aHjdHhS14nC1qnqDmybS1gpB27WK2-qVNz9cxW0,42156
29
+ cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jjCRqS29IzMnKp40jNOs80UKh2z9lK5S6M02GSB-4mk,42351
29
30
  cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
31
  cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
31
32
  cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
@@ -36,11 +37,18 @@ cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
36
37
  cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
37
38
  cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
38
39
  cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
- cehrgpt/runners/gpt_runner_util.py,sha256=88HKSVj-ADGBCMo7C3znKSMPnAAALa1iU_6P6i9sD0M,3867
40
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=aGw87ZJuUIH196ryaZzt9D4hCAHVcDyKnvvdVPdipwc,31568
41
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=6xulvnjwy6LDRPIL_zgsYH7sJMiXJ9AvFg3p2o35S6c,16510
42
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=2l1X5bp1zckoFp0rQkxGptXyG8u3PgNw0dqYVDWLYjg,5155
43
- cehrgpt/runners/hyperparameter_search_util.py,sha256=i4qAb_22JO78l40MSyBPwDgAGuGc96efXmg_833cSSo,9044
40
+ cehrgpt/runners/data_utils.py,sha256=ScZZnfXwgXKaMvKgFzdb4vtQ7F_lw97O5uNsFbfsyP4,10620
41
+ cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
42
+ cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=bkPl30Y9CSXBlmMkH-3cA3-aW8XJK36Q-adx___WjkE,26921
43
+ cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ViVa_flEGdk_SO0psMR7ho-o79igsz_l1x80u81WJ3A,23875
44
+ cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=VrqgDSiAMfGyHEIodoOg_8LU5O0ndWf9EE0YOKDFKKA,7019
45
+ cehrgpt/runners/hyperparameter_search_util.py,sha256=pWFmGo9Ezju4YmuZ-ohbAbYB0GGMfIDVUCyvcTxS1iU,9153
46
+ cehrgpt/runners/sample_packing_trainer.py,sha256=aezX30vxpP1DDcH5hO-yn395NqBKi2Xhb0mFNHi9OBs,7340
47
+ cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
+ cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
49
+ cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
50
+ cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
51
+ cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
44
52
  cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
53
  cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
46
54
  cehrgpt/time_to_event/time_to_event_prediction.py,sha256=Ajesq2gSsILghWHCTLiiBhcyOCa7m6JPPMdi_xvBlR4,12624
@@ -50,11 +58,15 @@ cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKI
50
58
  cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
51
59
  cehrgpt/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
52
60
  cehrgpt/tools/ehrshot_benchmark.py,sha256=E-m_5srlYEw7Y7i9twIJWDvrkwNlop-6yZB-80FZid0,2667
61
+ cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIiHH9dwMtMi3PVl8jXm2O06_gI,4177
53
62
  cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
54
63
  cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
55
64
  cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
56
- cehrgpt-0.0.1.dist-info/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
57
- cehrgpt-0.0.1.dist-info/METADATA,sha256=BZrsoZe0Smn4JoA3cCI63fC4nBvOVrC9sgZ0Ct1NJsA,3388
58
- cehrgpt-0.0.1.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
59
- cehrgpt-0.0.1.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
60
- cehrgpt-0.0.1.dist-info/RECORD,,
65
+ cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
66
+ cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=jVgAmBrZKp7ABfqKkzwV5Vl_G9jDCjPl98NSVmSwHpE,19291
67
+ cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
68
+ cehrgpt-0.1.0.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
69
+ cehrgpt-0.1.0.dist-info/METADATA,sha256=V02vsptjJRD_bybXVRFXPrJa-By9CX4j-oAA3EfXFq4,4933
70
+ cehrgpt-0.1.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
71
+ cehrgpt-0.1.0.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
72
+ cehrgpt-0.1.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.1)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5