cehrgpt 0.0.2__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.gitignore +3 -0
  2. {cehrgpt-0.0.2/src/cehrgpt.egg-info → cehrgpt-0.1.1}/PKG-INFO +11 -8
  3. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/README.md +4 -4
  4. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/pyproject.toml +5 -3
  5. cehrgpt-0.1.1/sample_data/omop_vocab/concept/concept.parquet +0 -0
  6. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/level_three_evaluation.sh +10 -6
  7. cehrgpt-0.1.1/scripts/run_linear_prob.sh +260 -0
  8. cehrgpt-0.1.1/src/cehrgpt/analysis/irregularity.py +36 -0
  9. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  10. cehrgpt-0.1.1/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +1020 -0
  11. cehrgpt-0.1.1/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +595 -0
  12. cehrgpt-0.1.1/src/cehrgpt/data/sample_packing_sampler.py +181 -0
  13. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  14. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/omop_converter_batch.py +32 -2
  15. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/gpt_utils.py +20 -2
  16. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/config.py +35 -0
  17. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/hf_cehrgpt.py +470 -106
  18. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/hf_modeling_outputs.py +1 -0
  19. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/special_tokens.py +1 -0
  20. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  21. cehrgpt-0.1.1/src/cehrgpt/runners/data_utils.py +358 -0
  22. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/gpt_runner_util.py +0 -10
  23. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  24. cehrgpt-0.1.1/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +582 -0
  25. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  26. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/hyperparameter_search_util.py +10 -8
  27. cehrgpt-0.1.1/src/cehrgpt/runners/sample_packing_trainer.py +185 -0
  28. cehrgpt-0.1.1/src/cehrgpt/simulations/generate_plots.py +95 -0
  29. cehrgpt-0.1.1/src/cehrgpt/simulations/run_simulation.sh +24 -0
  30. cehrgpt-0.1.1/src/cehrgpt/simulations/time_embedding_simulation.py +250 -0
  31. cehrgpt-0.1.1/src/cehrgpt/simulations/time_token_simulation.py +177 -0
  32. cehrgpt-0.1.1/src/cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  33. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_model.py +2 -13
  34. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  35. cehrgpt-0.1.1/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  36. cehrgpt-0.1.1/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  37. {cehrgpt-0.0.2 → cehrgpt-0.1.1/src/cehrgpt.egg-info}/PKG-INFO +11 -8
  38. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/SOURCES.txt +17 -9
  39. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/requires.txt +5 -3
  40. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +26 -11
  41. cehrgpt-0.1.1/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +122 -0
  42. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +9 -3
  43. cehrgpt-0.1.1/tests/unit_tests/models/model_utils_test.py +131 -0
  44. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +4 -4
  45. cehrgpt-0.1.1/tests/unit_tests/tools/__init__.py +0 -0
  46. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +0 -482
  47. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +0 -382
  48. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  49. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  50. cehrgpt-0.0.2/src/cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  51. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  52. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  53. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  54. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  55. cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  56. cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +0 -406
  57. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.github/workflows/build-python.yaml +0 -0
  58. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.github/workflows/tests.yaml +0 -0
  59. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.pre-commit-config.yaml +0 -0
  60. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/LICENSE +0 -0
  61. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
  62. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_data/pretrain/patient_sequence.parquet +0 -0
  63. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
  64. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
  65. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/omop_pipeline.sh +0 -0
  66. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/pool_generated_sequences.sh +0 -0
  67. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/setup.cfg +0 -0
  68. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/__init__.py +0 -0
  69. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/__init__.py +0 -0
  70. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/__init__.py +0 -0
  71. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
  72. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
  73. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
  74. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
  75. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
  76. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
  77. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/utils.py +0 -0
  78. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/cehrgpt_args.py +0 -0
  79. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/data/__init__.py +0 -0
  80. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/__init__.py +0 -0
  81. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
  82. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/omop_entity.py +0 -0
  83. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/__init__.py +0 -0
  84. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
  85. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/__init__.py +0 -0
  86. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/condition_era.py +0 -0
  87. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/observation_period.py +0 -0
  88. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/omop_argparse.py +0 -0
  89. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/omop_table_builder.py +0 -0
  90. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/__init__.py +0 -0
  91. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/condition_era.py +0 -0
  92. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/observation_period.py +0 -0
  93. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
  94. {cehrgpt-0.0.2/src/cehrgpt/rl_finetune → cehrgpt-0.1.1/src/cehrgpt/runners}/__init__.py +0 -0
  95. {cehrgpt-0.0.2/src/cehrgpt/runners → cehrgpt-0.1.1/src/cehrgpt/simulations}/__init__.py +0 -0
  96. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/__init__.py +0 -0
  97. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
  98. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
  99. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
  100. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
  101. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/__init__.py +0 -0
  102. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
  103. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
  104. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
  105. {cehrgpt-0.0.2/tests → cehrgpt-0.1.1/src/cehrgpt/tools/linear_prob}/__init__.py +0 -0
  106. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
  107. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
  108. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
  109. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/top_level.txt +0 -0
  110. {cehrgpt-0.0.2/tests/integration_tests → cehrgpt-0.1.1/tests}/__init__.py +0 -0
  111. {cehrgpt-0.0.2/tests/integration_tests/runners → cehrgpt-0.1.1/tests/integration_tests}/__init__.py +0 -0
  112. {cehrgpt-0.0.2/tests/unit_tests → cehrgpt-0.1.1/tests/integration_tests/runners}/__init__.py +0 -0
  113. {cehrgpt-0.0.2/tests/unit_tests/models → cehrgpt-0.1.1/tests/unit_tests}/__init__.py +0 -0
  114. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/gpt_utils_test.py +0 -0
  115. {cehrgpt-0.0.2/tests/unit_tests/models/tokenization → cehrgpt-0.1.1/tests/unit_tests/models}/__init__.py +0 -0
  116. {cehrgpt-0.0.2/tests/unit_tests/runners → cehrgpt-0.1.1/tests/unit_tests/models/tokenization}/__init__.py +0 -0
  117. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
  118. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
  119. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
  120. {cehrgpt-0.0.2/tests/unit_tests/tools → cehrgpt-0.1.1/tests/unit_tests/runners}/__init__.py +0 -0
  121. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/tokenization_test.py +0 -0
  122. {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
{cehrgpt-0.0.2 → cehrgpt-0.1.1}/.gitignore
@@ -4,6 +4,9 @@
 venv*
 dist/*

+*png
+*json
+
 *ipynb_checkpoints/
 *h5
 *logs
{cehrgpt-0.0.2/src/cehrgpt.egg-info → cehrgpt-0.1.1}/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.2
+Metadata-Version: 2.4
 Name: cehrgpt
-Version: 0.0.2
+Version: 0.1.1
 Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
 Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
 License: MIT License
@@ -12,13 +12,15 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10.0
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cehrbert==1.3.8
+Requires-Dist: cehrbert==1.4.5
+Requires-Dist: cehrbert_data==0.0.11
 Requires-Dist: openai==1.54.3
 Requires-Dist: optuna==4.0.0
-Requires-Dist: transformers==4.40.0
+Requires-Dist: transformers==4.44.1
 Requires-Dist: tokenizers==0.19.0
 Requires-Dist: peft==0.10.0
-Requires-Dist: trl==0.11.4
+Requires-Dist: lightgbm
+Requires-Dist: polars
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
@@ -29,14 +31,15 @@ Requires-Dist: hypothesis; extra == "dev"
 Requires-Dist: black; extra == "dev"
 Provides-Extra: flash-attn
 Requires-Dist: flash_attn; extra == "flash-attn"
+Dynamic: license-file

 # CEHRGPT

 [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
 ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
-[![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
-[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
-[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
+[![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)

 ## Description
 CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
{cehrgpt-0.0.2 → cehrgpt-0.1.1}/README.md
@@ -2,9 +2,9 @@

 [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
 ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
-[![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
-[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
-[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
+[![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)

 ## Description
 CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
@@ -77,4 +77,4 @@ sh scripts/omop_pipeline.sh \
   author={Natarajan, K and others},
   journal={arXiv preprint arXiv:2402.04400},
   year={2024}
-}
+}
{cehrgpt-0.0.2 → cehrgpt-0.1.1}/pyproject.toml
@@ -28,13 +28,15 @@ classifiers = [
 ]

 dependencies = [
-    "cehrbert==1.3.8",
+    "cehrbert==1.4.5",
+    "cehrbert_data==0.0.11",
     "openai==1.54.3",
     "optuna==4.0.0",
-    "transformers==4.40.0",
+    "transformers==4.44.1",
     "tokenizers==0.19.0",
     "peft==0.10.0",
-    "trl==0.11.4",
+    "lightgbm",
+    "polars",
 ]

 [tool.setuptools_scm]
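
Note on dependencies: 0.1.1 drops the RL-oriented `trl` pin and adds unpinned `lightgbm` and `polars`, which the new linear-probing and irregularity tooling draws on; `cehrbert` moves to 1.4.5, with `cehrbert_data` now pinned separately. A minimal post-upgrade sanity check (hypothetical shell session; only standard pip/python commands are assumed):

    pip install --upgrade "cehrgpt==0.1.1"
    # confirm the resolved pins match this diff
    pip show cehrbert cehrbert_data transformers | grep -E "^(Name|Version)"
    # confirm the two new runtime dependencies import cleanly
    python -c "import lightgbm, polars"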
{cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/level_three_evaluation.sh
@@ -29,7 +29,8 @@ python -u -m cehrbert_data.prediction_cohorts.cad_cabg_cohort \
 -dl 1985-01-01 -du 2023-12-31 \
 -l 18 -u 100 -ow 360 -ps 0 -pw 360 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts

 # Run Predictions on CAD CABG
 echo "Run predictions on cad_cabg"
@@ -56,9 +57,10 @@ python -u -m cehrbert_data.prediction_cohorts.hf_readmission \
 -c hf_readmission_bow \
 -i "$OMOP_FOLDER" \
 -o "$OMOP_FOLDER/cohorts/hf_readmission" \
--dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 0 -pw 30 -f \
+-dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 1 -pw 30 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts

 # Run predictions on HF Readmission
 echo "Run predictions on hf_readmission"
@@ -85,9 +87,10 @@ python -u -m cehrbert_data.prediction_cohorts.copd_readmission \
 -c copd_readmission_bow \
 -i "$OMOP_FOLDER" \
 -o "$OMOP_FOLDER/cohorts/copd_readmission" \
--dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 720 -ps 0 -pw 360 -f \
+-dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 1 -pw 30 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts

 # Run predictions on COPD Readmission
 echo "Run predictions on copd_readmission"
@@ -145,7 +148,8 @@ python -u -m cehrbert_data.prediction_cohorts.afib_ischemic_stroke \
 -o "$OMOP_FOLDER/cohorts/afib_ischemic_stroke" \
 -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 720 -ps 0 -pw 360 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts

 # Run predictions on AFIB Ischemic Stroke
 echo "Run predictions on afib_ischemic_stroke"
cehrgpt-0.1.1/scripts/run_linear_prob.sh (new file)
@@ -0,0 +1,260 @@
+#!/bin/sh
+
+# Function to display usage information
+usage() {
+    echo "Usage: $0 [options]"
+    echo ""
+    echo "Options:"
+    echo "  --base_dir=DIR                  Base directory containing cohorts (required)"
+    echo "  --dataset_prepared_path=PATH    Path to prepared dataset (required)"
+    echo "  --model_path=PATH               Path to pre-trained model and tokenizer (required)"
+    echo "  --preprocessing_workers=NUM     Number of preprocessing workers (required)"
+    echo "  --batch_size=NUM                Batch size for evaluation (required)"
+    echo "  --output_dir=DIR                Output directory for results (required)"
+    echo "  --model_name=NAME               Name for the model output directory (default: cehrgpt_model)"
+    echo "  --max_tokens_per_batch=NUM      Maximum tokens per batch (default: 16384)"
+    echo "  --torch_type=TYPE               Torch data type (default: bfloat16)"
+    echo "  --disable_sample_packing        Disable sample packing (enabled by default)"
+    echo ""
+    echo "Example:"
+    echo "  $0 --base_dir=/path/to/cohorts --dataset_prepared_path=/path/to/dataset_prepared \\"
+    echo "     --model_path=/path/to/model --preprocessing_workers=16 --batch_size=64 \\"
+    echo "     --output_dir=/path/to/outputs --model_name=my_model --torch_type=float16"
+    exit 1
+}
+
+# Default values
+MODEL_NAME="cehrgpt_model"
+MAX_TOKENS_PER_BATCH="16384"
+TORCH_TYPE="bfloat16"
+DISABLE_SAMPLE_PACKING="false"
+
+# Parse command line arguments
+for arg in "$@"; do
+    case $arg in
+        --base_dir=*)
+            BASE_DIR="${arg#*=}"
+            ;;
+        --dataset_prepared_path=*)
+            DATASET_PREPARED_PATH="${arg#*=}"
+            ;;
+        --model_path=*)
+            MODEL_PATH="${arg#*=}"
+            ;;
+        --preprocessing_workers=*)
+            PREPROCESSING_WORKERS="${arg#*=}"
+            ;;
+        --batch_size=*)
+            BATCH_SIZE="${arg#*=}"
+            ;;
+        --output_dir=*)
+            OUTPUT_DIR="${arg#*=}"
+            ;;
+        --model_name=*)
+            MODEL_NAME="${arg#*=}"
+            ;;
+        --max_tokens_per_batch=*)
+            MAX_TOKENS_PER_BATCH="${arg#*=}"
+            ;;
+        --torch_type=*)
+            TORCH_TYPE="${arg#*=}"
+            ;;
+        --disable_sample_packing)
+            DISABLE_SAMPLE_PACKING="true"
+            ;;
+        --help|-h)
+            usage
+            ;;
+        *)
+            echo "Error: Unknown option: $arg"
+            usage
+            ;;
+    esac
+done
+
+# Check for required arguments
+if [ -z "$BASE_DIR" ] || [ -z "$DATASET_PREPARED_PATH" ] || [ -z "$MODEL_PATH" ] || [ -z "$PREPROCESSING_WORKERS" ] || [ -z "$BATCH_SIZE" ] || [ -z "$OUTPUT_DIR" ]; then
+    echo "Error: Missing required arguments"
+    usage
+fi
+
+# Validate arguments
+if [ ! -d "$BASE_DIR" ]; then
+    echo "Error: Base directory does not exist: $BASE_DIR"
+    exit 1
+fi
+
+if [ ! -d "$DATASET_PREPARED_PATH" ]; then
+    echo "Error: Dataset prepared path does not exist: $DATASET_PREPARED_PATH"
+    exit 1
+fi
+
+if [ ! -d "$MODEL_PATH" ]; then
+    echo "Error: Model path does not exist: $MODEL_PATH"
+    exit 1
+fi
+
+# Create output directory if it doesn't exist
+mkdir -p "$OUTPUT_DIR"
+
+# Check if preprocessing workers is a number
+if ! [ "$PREPROCESSING_WORKERS" -eq "$PREPROCESSING_WORKERS" ] 2>/dev/null; then
+    echo "Error: Preprocessing workers must be a number: $PREPROCESSING_WORKERS"
+    exit 1
+fi
+
+# Check if batch size is a number
+if ! [ "$BATCH_SIZE" -eq "$BATCH_SIZE" ] 2>/dev/null; then
+    echo "Error: Batch size must be a number: $BATCH_SIZE"
+    exit 1
+fi
+
+# Check if max tokens per batch is a number
+if ! [ "$MAX_TOKENS_PER_BATCH" -eq "$MAX_TOKENS_PER_BATCH" ] 2>/dev/null; then
+    echo "Error: Max tokens per batch must be a number: $MAX_TOKENS_PER_BATCH"
+    exit 1
+fi
+
+# Validate torch_type (common PyTorch data types)
+case "$TORCH_TYPE" in
+    float16|float32|float64|bfloat16|int8|int16|int32|int64)
+        ;;
+    *)
+        echo "Error: Invalid torch_type. Supported types: float16, float32, float64, bfloat16, int8, int16, int32, int64"
+        exit 1
+        ;;
+esac
+
+# Validate disable_sample_packing is boolean-like
+if [ "$DISABLE_SAMPLE_PACKING" != "true" ] && [ "$DISABLE_SAMPLE_PACKING" != "false" ]; then
+    echo "Error: disable_sample_packing must be 'true' or 'false': $DISABLE_SAMPLE_PACKING"
+    exit 1
+fi
+
+# Log file setup
+LOG_DIR="$BASE_DIR/logs"
+mkdir -p "$LOG_DIR"
+TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
+MAIN_LOG="$LOG_DIR/run_${TIMESTAMP}.log"
+
+# Log function
+log() {
+    message="[$(date '+%Y-%m-%d %H:%M:%S')] $1"
+    echo "$message" | tee -a "$MAIN_LOG"
+}
+
+# Main execution
+log "Starting feature extraction and model training process"
+log "Configuration:"
+log "  --base_dir=$BASE_DIR"
+log "  --dataset_prepared_path=$DATASET_PREPARED_PATH"
+log "  --model_path=$MODEL_PATH"
+log "  --preprocessing_workers=$PREPROCESSING_WORKERS"
+log "  --batch_size=$BATCH_SIZE"
+log "  --output_dir=$OUTPUT_DIR"
+log "  --model_name=$MODEL_NAME"
+log "  --max_tokens_per_batch=$MAX_TOKENS_PER_BATCH"
+log "  --torch_type=$TORCH_TYPE"
+log "  --disable_sample_packing=$DISABLE_SAMPLE_PACKING"
+
+# Find valid cohorts and write to a temp file
+TEMP_COHORT_LIST="$LOG_DIR/cohort_list_${TIMESTAMP}.txt"
+> "$TEMP_COHORT_LIST" # Clear the file
+
+# Find all valid cohorts (directories with train and test subdirectories)
+for cohort_dir in "$BASE_DIR"/*; do
+    if [ -d "$cohort_dir" ] && [ -d "$cohort_dir/train" ] && [ -d "$cohort_dir/test" ]; then
+        cohort_name=$(basename "$cohort_dir")
+        echo "$cohort_name" >> "$TEMP_COHORT_LIST"
+    fi
+done
+
+# Check if any valid cohorts were found
+if [ ! -s "$TEMP_COHORT_LIST" ]; then
+    log "ERROR: No valid cohorts found in $BASE_DIR"
+    rm -f "$TEMP_COHORT_LIST"
+    exit 1
+fi
+
+# Display all cohorts that will be processed
+cohort_count=$(wc -l < "$TEMP_COHORT_LIST")
+log "Found $cohort_count cohorts to process:"
+while read -r cohort; do
+    log "- $cohort"
+done < "$TEMP_COHORT_LIST"
+
+# Process each cohort sequentially
+while read -r cohort_name; do
+    cohort_dir="$OUTPUT_DIR/$cohort_name"
+    output_dir="$cohort_dir/$MODEL_NAME"
+
+    log "===================================================="
+    log "Processing cohort: $cohort_name"
+    log "===================================================="
+
+    cohort_log="$LOG_DIR/${cohort_name}_${TIMESTAMP}.log"
+
+    # Create output directory if it doesn't exist
+    mkdir -p "$output_dir"
+
+    # Prepare command for feature extraction
+    FEATURE_CMD="python -u -m cehrgpt.tools.linear_prob.compute_cehrgpt_features \
+        --data_folder \"$BASE_DIR/$cohort_name/train/\" \
+        --test_data_folder \"$BASE_DIR/$cohort_name/test/\" \
+        --dataset_prepared_path \"$DATASET_PREPARED_PATH\" \
+        --model_name_or_path \"$MODEL_PATH\" \
+        --tokenizer_name_or_path \"$MODEL_PATH\" \
+        --output_dir \"$output_dir\" \
+        --preprocessing_num_workers \"$PREPROCESSING_WORKERS\" \
+        --per_device_eval_batch_size \"$BATCH_SIZE\" \
+        --max_tokens_per_batch \"$MAX_TOKENS_PER_BATCH\" \
+        --torch_type \"$TORCH_TYPE\""
+
+    # Add sample packing flag if not disabled
+    if [ "$DISABLE_SAMPLE_PACKING" = "false" ]; then
+        FEATURE_CMD="$FEATURE_CMD --sample_packing"
+    fi
+
+    # Step 1: Feature extraction
+    log "Starting feature extraction for $cohort_name..."
+    log "Command: $FEATURE_CMD"
+
+    eval "$FEATURE_CMD > \"$cohort_log\" 2>&1"
+
+    feature_extraction_status=$?
+    if [ $feature_extraction_status -ne 0 ]; then
+        log "ERROR: Feature extraction failed for $cohort_name. Check $cohort_log for details."
+        continue
+    fi
+
+    # Step 2: Model training
+    log "Starting model training for $cohort_name..."
+    log "Command: python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features --features_data_dir $output_dir --output_dir $output_dir"
+
+    python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features \
+        --features_data_dir "$output_dir" \
+        --output_dir "$output_dir" \
+        >> "$cohort_log" 2>&1
+
+    # Capture the training exit status before anything else overwrites $?
+    model_training_status=$?
+    if [ $model_training_status -ne 0 ]; then
+        log "ERROR: Model training failed for $cohort_name. Check $cohort_log for details."
+        continue
+    fi
+
+    # Step 3: meds-evaluation on the logistic regression predictions
+    log "Running meds-evaluation for logistic regression for $cohort_name..."
+    meds-evaluation-cli predictions_path="$output_dir/logistic/test_predictions" \
+        output_dir="$output_dir/logistic/"
+
+    # Check if meds-evaluation succeeded
+    if [ $? -ne 0 ]; then
+        log "ERROR: Running meds-evaluation failed for logistic regression for $cohort_name"
+    fi
+
+    log "Successfully completed processing for $cohort_name"
+done < "$TEMP_COHORT_LIST"
+
+# Clean up
+rm -f "$TEMP_COHORT_LIST"
+
+log "All processing completed"
cehrgpt-0.1.1/src/cehrgpt/analysis/irregularity.py (new file)
@@ -0,0 +1,36 @@
+import os
+
+import polars as pl
+
+from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
+
+
+def main(args):
+    dataset = pl.read_parquet(os.path.join(args.input_dir, "*.parquet"))
+    time_token_frequency_df = (
+        dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
+        .filter(pl.col("concept_id").map_elements(is_att_token))
+        .with_columns(
+            pl.col("concept_id")
+            .map_elements(extract_time_interval_in_days)
+            .alias("time_interval")
+        )
+    )
+    results = time_token_frequency_df.select(
+        pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
+    ).to_dicts()[0]
+    print(results)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="EHR Irregularity analysis")
+    parser.add_argument(
+        "--input_dir",
+        dest="input_dir",
+        action="store",
+        help="The path to the input data folder",
+        required=True,
+    )
+    main(parser.parse_args())
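
The new module reads every parquet file under `--input_dir`, keeps only the artificial time tokens (ATT) in each `concept_ids` sequence, and prints the mean and standard deviation of the day intervals they encode. A sketch of running it against the bundled pretraining sample (the path ships with this package; whether those sample sequences contain ATT tokens is an assumption):

    python -u -m cehrgpt.analysis.irregularity \
        --input_dir sample_data/pretrain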
{cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/data/hf_cehrgpt_dataset.py
@@ -1,9 +1,10 @@
-from typing import Union
+from typing import Optional, Union

 from cehrbert.data_generators.hf_data_generator.hf_dataset import (
     FINETUNING_COLUMNS,
     apply_cehrbert_dataset_mapping,
 )
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 from datasets import Dataset, DatasetDict

@@ -22,6 +23,7 @@ CEHRGPT_COLUMNS = [
     "num_of_visits",
     "values",
     "value_indicators",
+    "epoch_times",
 ]

 TRANSFORMER_COLUMNS = ["input_ids"]
@@ -31,16 +33,25 @@ def create_cehrgpt_pretraining_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Dataset:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
+    # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset = dataset.remove_columns(["visit_concept_ids"])
     dataset = apply_cehrbert_dataset_mapping(
         dataset,
         HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
-
     if not data_args.streaming:
         if isinstance(dataset, DatasetDict):
             all_columns = dataset["train"].column_names
@@ -56,8 +67,17 @@ def create_cehrgpt_finetuning_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Dataset:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
+    # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset = dataset.remove_columns(["visit_concept_ids"])
     mapping_functions = [
         HFFineTuningMapping(cehrgpt_tokenizer),
     ]
@@ -68,6 +88,7 @@
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )

     if not data_args.streaming:
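
Both factory functions now accept an optional `CacheFileCollector` and thread it through `apply_cehrbert_dataset_mapping`, letting callers clean up intermediate dataset cache files. A hedged sketch of exercising the pretraining path through the rewritten runner that consumes `create_cehrgpt_pretraining_dataset` (assuming the runner takes a yaml config path, as the bundled sample config suggests):

    python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
        sample_configs/cehrgpt_pretrain_sample_config.yaml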