cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
1
  __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  cehrgpt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  cehrgpt/cehrgpt_args.py,sha256=zPLp9Qjlq5PapWx3R15BNnyaX8zV3dxr4PuWj71r0Lg,3516
4
- cehrgpt/gpt_utils.py,sha256=IA5qw-hxcKkGO07AB47lDNRU6mlb9jblpKO7KeLLN78,11342
4
+ cehrgpt/gpt_utils.py,sha256=gMPqHpOS7_6N81r7t_p6bGJ0FFVK5AgtEIMYLYKb9iA,13746
5
5
  cehrgpt/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ cehrgpt/analysis/htn_treatment_pathway.py,sha256=KMjSEdIFNr2bSAyw1W6_bh59aV067-ZhT-AymiKCyr8,21961
6
7
  cehrgpt/analysis/irregularity.py,sha256=Rfl_daMvSh9cZ68vUwfmuH-JYCFXdAph2ITHHffYC0Y,1047
7
8
  cehrgpt/analysis/privacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
9
  cehrgpt/analysis/privacy/attribute_inference.py,sha256=0ANVW0I5uvOl6IxQ15-vMVQd0mugOgSGReBUQQESImg,9368
@@ -11,29 +12,38 @@ cehrgpt/analysis/privacy/member_inference.py,sha256=a_-4rkYYffYl0ucnjK6uYy8jesup
11
12
  cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTaK84iY_QLOuF3HCYXAEKLZOZsU,6391
12
13
  cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
13
14
  cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
15
+ cehrgpt/analysis/treatment_pathway/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py,sha256=7mrzaMBv09Gn6I5OM86f7gNfPvncVVKg2C3jZo0bmsU,3024
17
+ cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py,sha256=qwAtJ3KVesvqvR22Tbk19k35sDL-sGlRZo2sjJNo3yQ,2962
18
+ cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py,sha256=0bsEE1VFIxzU33bSipM30p2fnHsWjGWWcu59y_38K3c,2870
19
+ cehrgpt/analysis/treatment_pathway/treatment_pathway.py,sha256=SCWphYH9ARa4ZKB9fgBYM9RC2Hc8PDwtoHHCX7th16Q,25496
14
20
  cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
16
- cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=juM5HeZScgj8w15Bl1qC83Swld4gY6avh0QkSWLqITA,45465
17
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=_QDX9NXfmQ_S3kOf3yndb3AhoEeFiSzAOv836uYW0AY,26230
21
+ cehrgpt/data/cehrgpt_data_processor.py,sha256=0Y6GPWu6fRBLemXJu5IxuOPbF2wmSrX-18uyofTeUzk,23096
22
+ cehrgpt/data/hf_cehrgpt_dataset.py,sha256=uz05TG5QCl3_Ybn9zZyWRg0pEbiAvL1yPWXK3BGsj0Q,3815
23
+ cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=2UcYB241dWhvS-mV0ZTbCJdjlgPrVjZOAh3V8EWFfCg,27930
24
+ cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=-Igd-P-yvYlJXGZSGlYHRnez464NCkZIko3boQDYS1E,27638
18
25
  cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
19
26
  cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- cehrgpt/generation/cehrgpt_conditional_generation.py,sha256=AM76yaPyw1B-bcdei24HO0uspGZWHGKWpYpHywotTIQ,11972
27
+ cehrgpt/generation/cehrgpt_conditional_generation.py,sha256=6I4tI-cCQ6QdFxhDAkhu0ZNo57DINjD-NncxMbyUwgg,12032
21
28
  cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
22
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=P8al4-zqymqEkCHCCu2sqz_45akcKF2o_AtQIjJdVmQ,11919
23
- cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
29
+ cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=lpKEvJ2hhB8bwS06c5jEAksFUrGKCUv6t7hXrsMj-Ns,12284
30
+ cehrgpt/generation/omop_converter_batch.py,sha256=h4dg9fc23w6i82KMrOQFM-KxD6iuLnJfrv7YISc0dMw,26620
24
31
  cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
25
32
  cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
27
- cehrgpt/models/hf_cehrgpt.py,sha256=3P7bOLDr7NMSedGszhmlJJN4Mhpd_65-x6uzwvSjigE,92837
33
+ cehrgpt/models/activations.py,sha256=crVPS-cZpUGrvLD7xhNjGmGr9S4e4LEfNmgIEsiuQ88,981
34
+ cehrgpt/models/config.py,sha256=SwsHVXzsgDmFSfrzv90lZBePenoHv-fIGGSLdxAIiu8,11193
35
+ cehrgpt/models/gpt2.py,sha256=4H9sFzf_qFGY-Bk0mfztxlKJXxvA0kTKwKiWFbqJLrQ,22079
36
+ cehrgpt/models/hf_cehrgpt.py,sha256=YTZtY1p-M-utQa6iJvDXFOjgc1SDdL3ZcWuy_-ZN41g,81167
28
37
  cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
29
38
  cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
30
39
  cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
31
- cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=cAxHTctpVBxfWfC3XcwDQavN1zwWN9Nid_Fajd5zQWQ,53159
40
+ cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=yHuNXvLznaSjwxVJsq7r9bZLi4msM8n4LVrzHINqsgY,66225
32
41
  cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
42
  cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
34
43
  cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
35
44
  cehrgpt/omop/omop_argparse.py,sha256=WI_-vZGfPdZ8atIeB-CrpaPdkv07kDBabyEpaRZfl64,998
36
45
  cehrgpt/omop/omop_table_builder.py,sha256=6K_YYKyayDUBwxUdwaliI5tufpfIQqByDY5HeBbjHok,2742
46
+ cehrgpt/omop/ontology.py,sha256=LZIp0X3gY_VDZqIl6gTwGq7ZwV1nb0raPLTQAbJm6nM,5683
37
47
  cehrgpt/omop/sample_omop_tables.py,sha256=2JZ8BNSvssceinwFanvuCRh-YlKrKn25U9w1pL79kQ0,2300
38
48
  cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
49
  cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
@@ -41,11 +51,11 @@ cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyh
41
51
  cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
52
  cehrgpt/runners/data_utils.py,sha256=i-krtBx_6rvPYtdLdDoWwOTtJcaovd0wH8gBYmgN2l4,16013
43
53
  cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
44
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=1OgxLm4T7iHv5pKi2QaSdaz9ogWo2n3sSUGp6cHDF9s,28309
45
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
46
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=fJR4RHPqal1YI6_KUH-WlkoQLSZuBT5bKUGfPHDFrWI,9350
47
- cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
48
- cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
54
+ cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=AY9QxH4WupfWpLm9rjeSMOzedmw_03kTWuhncVRuhqs,26032
55
+ cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=I_fuuKNzWx6yZiDcAAZdQtyxUEgNKLygQyS-SyQpptY,26840
56
+ cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=8qHVUp-hx7xKozaE_EaEJphrs1QfRSXx0P6YMByK9Ww,9981
57
+ cehrgpt/runners/hyperparameter_search_util.py,sha256=SD02j1D8IBtIOG41dh7VgmVT2SWCF-VPZ7zVHlEIN70,12801
58
+ cehrgpt/runners/sample_packing_trainer.py,sha256=HfxHCIGBXb1RbN7nbU6jmSy_Zzwx_joj-UoYqbKl5-0,8375
49
59
  cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
60
  cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
51
61
  cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
@@ -63,13 +73,13 @@ cehrgpt/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
63
73
  cehrgpt/tools/ehrshot_benchmark.py,sha256=E-m_5srlYEw7Y7i9twIJWDvrkwNlop-6yZB-80FZid0,2667
64
74
  cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIiHH9dwMtMi3PVl8jXm2O06_gI,4177
65
75
  cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
66
- cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
76
+ cehrgpt/tools/merge_synthetic_real_datasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
67
77
  cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
68
78
  cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=Hpx7WvAWm2WwPHFfimCADXh019I7bwdzJ4_5_YCxQzU,19817
79
+ cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=0i34zAwePG0hZK2HSDaUlO-Fzyb5K4LqRuhrCVWivxA,19906
70
80
  cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
71
- cehrgpt-0.1.2.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
72
- cehrgpt-0.1.2.dist-info/METADATA,sha256=D7gGKrQThiLivViFeNm711NCP8J-wXfkueMGb6RKqV0,8481
73
- cehrgpt-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
74
- cehrgpt-0.1.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
75
- cehrgpt-0.1.2.dist-info/RECORD,,
81
+ cehrgpt-0.1.3.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
82
+ cehrgpt-0.1.3.dist-info/METADATA,sha256=MTgv1L9ru4evziAW2yTLsd3m9d1Ept8xy85u2CpBNTM,10167
83
+ cehrgpt-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
84
+ cehrgpt-0.1.3.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
85
+ cehrgpt-0.1.3.dist-info/RECORD,,
@@ -1,209 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: cehrgpt
3
- Version: 0.1.2
4
- Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
- Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
- License: MIT License
7
- Classifier: Development Status :: 5 - Production/Stable
8
- Classifier: Intended Audience :: Developers
9
- Classifier: Intended Audience :: Science/Research
10
- Classifier: License :: OSI Approved :: MIT License
11
- Classifier: Programming Language :: Python :: 3
12
- Requires-Python: >=3.10.0
13
- Description-Content-Type: text/markdown
14
- License-File: LICENSE
15
- Requires-Dist: cehrbert==1.4.5
16
- Requires-Dist: cehrbert_data==0.0.11
17
- Requires-Dist: openai==1.54.3
18
- Requires-Dist: optuna==4.0.0
19
- Requires-Dist: transformers==4.44.1
20
- Requires-Dist: tokenizers==0.19.0
21
- Requires-Dist: peft==0.10.0
22
- Requires-Dist: lightgbm
23
- Requires-Dist: polars
24
- Provides-Extra: dev
25
- Requires-Dist: pre-commit; extra == "dev"
26
- Requires-Dist: pytest; extra == "dev"
27
- Requires-Dist: pytest-cov; extra == "dev"
28
- Requires-Dist: pytest-subtests; extra == "dev"
29
- Requires-Dist: rootutils; extra == "dev"
30
- Requires-Dist: hypothesis; extra == "dev"
31
- Requires-Dist: black; extra == "dev"
32
- Provides-Extra: flash-attn
33
- Requires-Dist: flash_attn; extra == "flash-attn"
34
- Dynamic: license-file
35
-
36
- # CEHRGPT
37
-
38
- [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
39
- ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
40
- [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
41
- [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
42
- [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
43
-
44
- ## Description
45
- CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
46
-
47
- ## Features
48
- - **Synthetic Patient Data Generation**: Generates comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes.
49
- - **Privacy-Preserving**: Implements techniques to ensure the generated data does not reveal identifiable information.
50
- - **Compatibility with OMOP**: Fully compatible with the OMOP common data model, allowing seamless integration with existing healthcare data systems.
51
- - **Extensible**: Designed to be adaptable to new datasets and different EHR systems.
52
-
53
- ## Installation
54
- To install CEHRGPT, clone this repository and install the required dependencies.
55
-
56
- ```bash
57
- git clone https://github.com/knatarajan-lab/cehrgpt.git
58
- cd cehrgpt
59
- pip install .
60
- ```
61
-
62
- ## Pretrain
63
- Pretrain CEHR-GPT using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file.
64
- ```bash
65
- mkdir test_results
66
- # This is NOT required when streaming is set to true
67
- mkdir test_dataset_prepared
68
- python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
69
- ```
70
-
71
- ## Generate synthetic sequences
72
- Generate synthetic sequences using the trained model:
73
- ```bash
74
- export TRANSFORMERS_VERBOSITY=info
75
- export CUDA_VISIBLE_DEVICES="0"
76
- python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
77
- --model_folder test_results \
78
- --tokenizer_folder test_results \
79
- --output_folder test_results \
80
- --num_of_patients 128 \
81
- --batch_size 32 \
82
- --buffer_size 128 \
83
- --context_window 1024 \
84
- --sampling_strategy TopPStrategy \
85
- --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
86
- --epsilon_cutoff 0.00 \
87
- --demographic_data_path sample_data/pretrain
88
- ```
89
-
90
- ## Convert synthetic sequences to OMOP
91
- ```bash
92
- # omop converter requires the OHDSI vocabulary
93
- export OMOP_VOCAB_DIR=""
94
- # the omop derived tables need to be built using pyspark
95
- export SPARK_WORKER_INSTANCES="1"
96
- export SPARK_WORKER_CORES="8"
97
- export SPARK_EXECUTOR_CORES="2"
98
- export SPARK_DRIVER_MEMORY="2g"
99
- export SPARK_EXECUTOR_MEMORY="2g"
100
-
101
- # Convert the sequences, create the omop derived tables
102
- sh scripts/omop_pipeline.sh \
103
- test_results/top_p10000/generated_sequences/ \
104
- test_results/top_p10000/restored_omop/ \
105
- $OMOP_VOCAB_DIR
106
- ```
107
-
108
- # MEDS Support
109
-
110
- This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
111
-
112
- ## Prerequisites
113
-
114
- Set up the required environment variables before beginning:
115
-
116
- ```bash
117
- export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
118
- export MEDS_DIR="" # Path to MEDS data directory
119
- export MEDS_READER_DIR="" # Path to MEDS reader output directory
120
- ```
121
-
122
- ## Step 1: Create MIMIC MEDS Data
123
-
124
- Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
125
-
126
- ## Step 2: Create the MEDS Reader
127
-
128
- Convert the MEDS data for use with CEHR-GPT:
129
-
130
- ```bash
131
- meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
132
- ```
133
-
134
- ## Step 3: Pretrain CEHR-GPT
135
-
136
- Run the pretraining process using the prepared MEDS data:
137
-
138
- ```bash
139
- python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
140
- --model_name_or_path $CEHR_GPT_MODEL_DIR \
141
- --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
142
- --output_dir $CEHR_GPT_MODEL_DIR \
143
- --data_folder $MEDS_READER_DIR \
144
- --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
145
- --do_train true --seed 42 \
146
- --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
147
- --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
148
- --evaluation_strategy epoch --save_strategy epoch \
149
- --sample_packing --max_tokens_per_batch 16384 \
150
- --warmup_steps 500 --weight_decay 0.01 \
151
- --num_train_epochs 50 --learning_rate 0.0002 \
152
- --use_early_stopping --early_stopping_threshold 0.001 \
153
- --is_data_in_meds --inpatient_att_function_type day \
154
- --att_function_type day --include_inpatient_hour_token \
155
- --include_auxiliary_token --include_demographic_prompt \
156
- --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
157
- ```
158
-
159
- ## Step 4: Generate MEDS Trajectories
160
-
161
- ### Environment Setup for Trajectory Generation
162
-
163
- Configure additional environment variables for trajectory generation with task labels (`subject_id`, `prediction_time`, `boolean_value` [optional]):
164
-
165
- ```bash
166
- # MEDS_LABEL_COHORT_DIR must contain a set of parquet files
167
- export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
168
- export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
169
- ```
170
-
171
- ### Generate Trajectories
172
-
173
- Create synthetic patient trajectories using the trained model:
174
-
175
- > **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
176
-
177
- ```bash
178
- python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
179
- --cohort_folder $MEDS_LABEL_COHORT_DIR \
180
- --data_folder $MEDS_READER_DIR \
181
- --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
182
- --model_name_or_path $CEHR_GPT_MODEL_DIR \
183
- --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
184
- --output_dir $MEDS_TRAJECTORY_DIR \
185
- --per_device_eval_batch_size 16 \
186
- --num_of_trajectories_per_sample 2 \
187
- --generation_input_length 4096 \
188
- --generation_max_new_tokens 4096 \
189
- --is_data_in_meds \
190
- --att_function_type day --inpatient_att_function_type day \
191
- --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
192
- --include_auxiliary_token --include_demographic_prompt \
193
- --include_inpatient_hour_token
194
- ```
195
-
196
- ### Parameters Explanation
197
-
198
- - `generation_input_length`: Controls the length of input context for generation
199
- - `generation_max_new_tokens`: Maximum number of new tokens to generate
200
- - `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
201
-
202
- ## Citation
203
- ```
204
- @article{cehrgpt2024,
205
- title={CEHRGPT: Synthetic Data Generation for Electronic Health Records},
206
- author={Natarajan, K and others},
207
- journal={arXiv preprint arXiv:2402.04400},
208
- year={2024}
209
- }