cehrgpt 0.0.2__tar.gz → 0.1.1__tar.gz
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.gitignore +3 -0
- {cehrgpt-0.0.2/src/cehrgpt.egg-info → cehrgpt-0.1.1}/PKG-INFO +11 -8
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/README.md +4 -4
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/pyproject.toml +5 -3
- cehrgpt-0.1.1/sample_data/omop_vocab/concept/concept.parquet +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/level_three_evaluation.sh +10 -6
- cehrgpt-0.1.1/scripts/run_linear_prob.sh +260 -0
- cehrgpt-0.1.1/src/cehrgpt/analysis/irregularity.py +36 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt-0.1.1/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +1020 -0
- cehrgpt-0.1.1/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +595 -0
- cehrgpt-0.1.1/src/cehrgpt/data/sample_packing_sampler.py +181 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/omop_converter_batch.py +32 -2
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/gpt_utils.py +20 -2
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/config.py +35 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/hf_cehrgpt.py +470 -106
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/hf_modeling_outputs.py +1 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/special_tokens.py +1 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt-0.1.1/src/cehrgpt/runners/data_utils.py +358 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/gpt_runner_util.py +0 -10
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt-0.1.1/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +582 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt-0.1.1/src/cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt-0.1.1/src/cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt-0.1.1/src/cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt-0.1.1/src/cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt-0.1.1/src/cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt-0.1.1/src/cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_model.py +2 -13
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt-0.1.1/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt-0.1.1/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1/src/cehrgpt.egg-info}/PKG-INFO +11 -8
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/SOURCES.txt +17 -9
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/requires.txt +5 -3
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +26 -11
- cehrgpt-0.1.1/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +122 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +9 -3
- cehrgpt-0.1.1/tests/unit_tests/models/model_utils_test.py +131 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +4 -4
- cehrgpt-0.1.1/tests/unit_tests/tools/__init__.py +0 -0
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +0 -482
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +0 -382
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt-0.0.2/src/cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +0 -406
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.github/workflows/build-python.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.github/workflows/tests.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/.pre-commit-config.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/LICENSE +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_data/pretrain/patient_sequence.parquet +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/omop_pipeline.sh +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/scripts/pool_generated_sequences.sh +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/setup.cfg +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/utils.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/cehrgpt_args.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/data/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/generation/omop_entity.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/condition_era.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/observation_period.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/omop_argparse.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/omop_table_builder.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/condition_era.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/observation_period.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
- {cehrgpt-0.0.2/src/cehrgpt/rl_finetune → cehrgpt-0.1.1/src/cehrgpt/runners}/__init__.py +0 -0
- {cehrgpt-0.0.2/src/cehrgpt/runners → cehrgpt-0.1.1/src/cehrgpt/simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
- {cehrgpt-0.0.2/tests → cehrgpt-0.1.1/src/cehrgpt/tools/linear_prob}/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/top_level.txt +0 -0
- {cehrgpt-0.0.2/tests/integration_tests → cehrgpt-0.1.1/tests}/__init__.py +0 -0
- {cehrgpt-0.0.2/tests/integration_tests/runners → cehrgpt-0.1.1/tests/integration_tests}/__init__.py +0 -0
- {cehrgpt-0.0.2/tests/unit_tests → cehrgpt-0.1.1/tests/integration_tests/runners}/__init__.py +0 -0
- {cehrgpt-0.0.2/tests/unit_tests/models → cehrgpt-0.1.1/tests/unit_tests}/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/gpt_utils_test.py +0 -0
- {cehrgpt-0.0.2/tests/unit_tests/models/tokenization → cehrgpt-0.1.1/tests/unit_tests/models}/__init__.py +0 -0
- {cehrgpt-0.0.2/tests/unit_tests/runners → cehrgpt-0.1.1/tests/unit_tests/models/tokenization}/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
- {cehrgpt-0.0.2/tests/unit_tests/tools → cehrgpt-0.1.1/tests/unit_tests/runners}/__init__.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/tokenization_test.py +0 -0
- {cehrgpt-0.0.2 → cehrgpt-0.1.1}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
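Taken together, 0.1.1 introduces sample packing (a sample_packing_sampler, a rewritten dataset collator, and a dedicated SamplePackingTrainer), a linear-probing toolchain (compute_cehrgpt_features, train_with_cehrgpt_features, and the run_linear_prob.sh driver), time-embedding and time-token simulations, and an EHR-irregularity analysis script, while removing the PPO/DPO reinforcement-learning fine-tuning stack (rl_finetune, the DPO collators and runners, and paired-sequence generation) that shipped in 0.0.2. The notable diffs are reproduced below.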
--- cehrgpt-0.0.2/src/cehrgpt.egg-info/PKG-INFO
+++ cehrgpt-0.1.1/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: cehrgpt
-Version: 0.0.2
+Version: 0.1.1
 Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
 Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
 License: MIT License
@@ -12,13 +12,15 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10.0
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cehrbert==1.
+Requires-Dist: cehrbert==1.4.5
+Requires-Dist: cehrbert_data==0.0.11
 Requires-Dist: openai==1.54.3
 Requires-Dist: optuna==4.0.0
-Requires-Dist: transformers==4.
+Requires-Dist: transformers==4.44.1
 Requires-Dist: tokenizers==0.19.0
 Requires-Dist: peft==0.10.0
-Requires-Dist:
+Requires-Dist: lightgbm
+Requires-Dist: polars
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
@@ -29,14 +31,15 @@ Requires-Dist: hypothesis; extra == "dev"
 Requires-Dist: black; extra == "dev"
 Provides-Extra: flash-attn
 Requires-Dist: flash_attn; extra == "flash-attn"
+Dynamic: license-file
 
 # CEHRGPT
 
 [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
 ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
-[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/knatarajan-lab/cehrgpt
-[![Tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
 
 ## Description
 CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
--- cehrgpt-0.0.2/README.md
+++ cehrgpt-0.1.1/README.md
@@ -2,9 +2,9 @@
 
 [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
 ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
-[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/knatarajan-lab/cehrgpt
-[![Tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
 
 ## Description
 CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
@@ -77,4 +77,4 @@ sh scripts/omop_pipeline.sh \
   author={Natarajan, K and others},
   journal={arXiv preprint arXiv:2402.04400},
   year={2024}
-}
+}
--- cehrgpt-0.0.2/pyproject.toml
+++ cehrgpt-0.1.1/pyproject.toml
@@ -28,13 +28,15 @@ classifiers = [
 ]
 
 dependencies = [
-    "cehrbert==1.
+    "cehrbert==1.4.5",
+    "cehrbert_data==0.0.11",
     "openai==1.54.3",
     "optuna==4.0.0",
-    "transformers==4.
+    "transformers==4.44.1",
     "tokenizers==0.19.0",
     "peft==0.10.0",
-    "
+    "lightgbm",
+    "polars",
 ]
 
 [tool.setuptools_scm]
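The dependency changes in pyproject.toml mirror the PKG-INFO metadata above: cehrbert is pinned to 1.4.5, transformers to 4.44.1, and cehrbert_data, lightgbm, and polars are new requirements, so upgrading with `pip install --upgrade cehrgpt==0.1.1` pulls them in automatically.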
--- /dev/null
+++ cehrgpt-0.1.1/sample_data/omop_vocab/concept/concept.parquet
Binary file (new in 0.1.1)
--- cehrgpt-0.0.2/scripts/level_three_evaluation.sh
+++ cehrgpt-0.1.1/scripts/level_three_evaluation.sh
@@ -29,7 +29,8 @@ python -u -m cehrbert_data.prediction_cohorts.cad_cabg_cohort \
 -dl 1985-01-01 -du 2023-12-31 \
 -l 18 -u 100 -ow 360 -ps 0 -pw 360 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts
 
 # Run Predictions on CAD CABG
 echo "Run predictions on cad_cabg"
@@ -56,9 +57,10 @@ python -u -m cehrbert_data.prediction_cohorts.hf_readmission \
 -c hf_readmission_bow \
 -i "$OMOP_FOLDER" \
 -o "$OMOP_FOLDER/cohorts/hf_readmission" \
--dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps
+-dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 1 -pw 30 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts
 
 # Run predictions on HF Readmission
 echo "Run predictions on hf_readmission"
@@ -85,9 +87,10 @@ python -u -m cehrbert_data.prediction_cohorts.copd_readmission \
 -c copd_readmission_bow \
 -i "$OMOP_FOLDER" \
 -o "$OMOP_FOLDER/cohorts/copd_readmission" \
--dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow
+-dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 1 -pw 30 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts
 
 # Run predictions on COPD Readmission
 echo "Run predictions on copd_readmission"
@@ -145,7 +148,8 @@ python -u -m cehrbert_data.prediction_cohorts.afib_ischemic_stroke \
 -o "$OMOP_FOLDER/cohorts/afib_ischemic_stroke" \
 -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 720 -ps 0 -pw 360 -f \
 --att_type cehr_bert \
---ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+--ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+--is_remove_index_prediction_starts
 
 # Run predictions on AFIB Ischemic Stroke
 echo "Run predictions on afib_ischemic_stroke"
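The same change repeats across all four cohort-building commands in level_three_evaluation.sh: every cehrbert_data prediction-cohort invocation now ends with --is_remove_index_prediction_starts, and the hf_readmission and copd_readmission cohorts switch to -ps 1 -pw 30, which (assuming cehrbert_data's usual convention of prediction start and prediction window in days) moves them to a 30-day readmission window.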
--- /dev/null
+++ cehrgpt-0.1.1/scripts/run_linear_prob.sh
@@ -0,0 +1,260 @@
+#!/bin/sh
+
+# Function to display usage information
+usage() {
+  echo "Usage: $0 [options]"
+  echo ""
+  echo "Options:"
+  echo "  --base_dir=DIR                 Base directory containing cohorts (required)"
+  echo "  --dataset_prepared_path=PATH   Path to prepared dataset (required)"
+  echo "  --model_path=PATH              Path to pre-trained model and tokenizer (required)"
+  echo "  --preprocessing_workers=NUM    Number of preprocessing workers (required)"
+  echo "  --batch_size=NUM               Batch size for evaluation (required)"
+  echo "  --output_dir=DIR               Output directory for results (required)"
+  echo "  --model_name=NAME              Name for the model output directory (default: cehrgpt_model)"
+  echo "  --max_tokens_per_batch=NUM     Maximum tokens per batch (default: 16384)"
+  echo "  --torch_type=TYPE              Torch data type (default: bfloat16)"
+  echo "  --disable_sample_packing       Disable sample packing (enabled by default)"
+  echo ""
+  echo "Example:"
+  echo "  $0 --base_dir=/path/to/cohorts --dataset_prepared_path=/path/to/dataset_prepared \\"
+  echo "     --model_path=/path/to/model --preprocessing_workers=16 --batch_size=64 \\"
+  echo "     --output_dir=/path/to/outputs --model_name=my_model --torch_type=float16"
+  exit 1
+}
+
+# Default values
+MODEL_NAME="cehrgpt_model"
+MAX_TOKENS_PER_BATCH="16384"
+TORCH_TYPE="bfloat16"
+DISABLE_SAMPLE_PACKING="false"
+
+# Parse command line arguments
+for arg in "$@"; do
+  case $arg in
+    --base_dir=*)
+      BASE_DIR="${arg#*=}"
+      ;;
+    --dataset_prepared_path=*)
+      DATASET_PREPARED_PATH="${arg#*=}"
+      ;;
+    --model_path=*)
+      MODEL_PATH="${arg#*=}"
+      ;;
+    --preprocessing_workers=*)
+      PREPROCESSING_WORKERS="${arg#*=}"
+      ;;
+    --batch_size=*)
+      BATCH_SIZE="${arg#*=}"
+      ;;
+    --output_dir=*)
+      OUTPUT_DIR="${arg#*=}"
+      ;;
+    --model_name=*)
+      MODEL_NAME="${arg#*=}"
+      ;;
+    --max_tokens_per_batch=*)
+      MAX_TOKENS_PER_BATCH="${arg#*=}"
+      ;;
+    --torch_type=*)
+      TORCH_TYPE="${arg#*=}"
+      ;;
+    --disable_sample_packing)
+      DISABLE_SAMPLE_PACKING="true"
+      ;;
+    --help|-h)
+      usage
+      ;;
+    *)
+      echo "Error: Unknown option: $arg"
+      usage
+      ;;
+  esac
+done
+
+# Check for required arguments
+if [ -z "$BASE_DIR" ] || [ -z "$DATASET_PREPARED_PATH" ] || [ -z "$MODEL_PATH" ] || [ -z "$PREPROCESSING_WORKERS" ] || [ -z "$BATCH_SIZE" ] || [ -z "$OUTPUT_DIR" ]; then
+  echo "Error: Missing required arguments"
+  usage
+fi
+
+# Validate arguments
+if [ ! -d "$BASE_DIR" ]; then
+  echo "Error: Base directory does not exist: $BASE_DIR"
+  exit 1
+fi
+
+if [ ! -d "$DATASET_PREPARED_PATH" ]; then
+  echo "Error: Dataset prepared path does not exist: $DATASET_PREPARED_PATH"
+  exit 1
+fi
+
+if [ ! -d "$MODEL_PATH" ]; then
+  echo "Error: Model path does not exist: $MODEL_PATH"
+  exit 1
+fi
+
+# Create output directory if it doesn't exist
+mkdir -p "$OUTPUT_DIR"
+
+# Check if preprocessing workers is a number
+if ! [ "$PREPROCESSING_WORKERS" -eq "$PREPROCESSING_WORKERS" ] 2>/dev/null; then
+  echo "Error: Preprocessing workers must be a number: $PREPROCESSING_WORKERS"
+  exit 1
+fi
+
+# Check if batch size is a number
+if ! [ "$BATCH_SIZE" -eq "$BATCH_SIZE" ] 2>/dev/null; then
+  echo "Error: Batch size must be a number: $BATCH_SIZE"
+  exit 1
+fi
+
+# Check if max tokens per batch is a number
+if ! [ "$MAX_TOKENS_PER_BATCH" -eq "$MAX_TOKENS_PER_BATCH" ] 2>/dev/null; then
+  echo "Error: Max tokens per batch must be a number: $MAX_TOKENS_PER_BATCH"
+  exit 1
+fi
+
+# Validate torch_type (common PyTorch data types)
+case "$TORCH_TYPE" in
+  float16|float32|float64|bfloat16|int8|int16|int32|int64)
+    ;;
+  *)
+    echo "Error: Invalid torch_type. Supported types: float16, float32, float64, bfloat16, int8, int16, int32, int64"
+    exit 1
+    ;;
+esac
+
+# Validate disable_sample_packing is boolean-like
+if [ "$DISABLE_SAMPLE_PACKING" != "true" ] && [ "$DISABLE_SAMPLE_PACKING" != "false" ]; then
+  echo "Error: disable_sample_packing must be 'true' or 'false': $DISABLE_SAMPLE_PACKING"
+  exit 1
+fi
+
+# Log file setup
+LOG_DIR="$BASE_DIR/logs"
+mkdir -p "$LOG_DIR"
+TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
+MAIN_LOG="$LOG_DIR/run_${TIMESTAMP}.log"
+
+# Log function
+log() {
+  message="[$(date '+%Y-%m-%d %H:%M:%S')] $1"
+  echo "$message" | tee -a "$MAIN_LOG"
+}
+
+# Main execution
+log "Starting feature extraction and model training process"
+log "Configuration:"
+log "  --base_dir=$BASE_DIR"
+log "  --dataset_prepared_path=$DATASET_PREPARED_PATH"
+log "  --model_path=$MODEL_PATH"
+log "  --preprocessing_workers=$PREPROCESSING_WORKERS"
+log "  --batch_size=$BATCH_SIZE"
+log "  --output_dir=$OUTPUT_DIR"
+log "  --model_name=$MODEL_NAME"
+log "  --max_tokens_per_batch=$MAX_TOKENS_PER_BATCH"
+log "  --torch_type=$TORCH_TYPE"
+log "  --disable_sample_packing=$DISABLE_SAMPLE_PACKING"
+
+# Find valid cohorts and write to a temp file
+TEMP_COHORT_LIST="$LOG_DIR/cohort_list_${TIMESTAMP}.txt"
+> "$TEMP_COHORT_LIST" # Clear the file
+
+# Find all valid cohorts (directories with train and test subdirectories)
+for cohort_dir in "$BASE_DIR"/*; do
+  if [ -d "$cohort_dir" ] && [ -d "$cohort_dir/train" ] && [ -d "$cohort_dir/test" ]; then
+    cohort_name=$(basename "$cohort_dir")
+    echo "$cohort_name" >> "$TEMP_COHORT_LIST"
+  fi
+done
+
+# Check if any valid cohorts were found
+if [ ! -s "$TEMP_COHORT_LIST" ]; then
+  log "ERROR: No valid cohorts found in $BASE_DIR"
+  rm -f "$TEMP_COHORT_LIST"
+  exit 1
+fi
+
+# Display all cohorts that will be processed
+cohort_count=$(wc -l < "$TEMP_COHORT_LIST")
+log "Found $cohort_count cohorts to process:"
+while read -r cohort; do
+  log "- $cohort"
+done < "$TEMP_COHORT_LIST"
+
+# Process each cohort sequentially
+while read -r cohort_name; do
+  cohort_dir="$OUTPUT_DIR/$cohort_name"
+  output_dir="$cohort_dir/$MODEL_NAME"
+
+  log "===================================================="
+  log "Processing cohort: $cohort_name"
+  log "===================================================="
+
+  cohort_log="$LOG_DIR/${cohort_name}_${TIMESTAMP}.log"
+
+  # Create output directory if it doesn't exist
+  mkdir -p "$output_dir"
+
+  # Prepare command for feature extraction
+  FEATURE_CMD="python -u -m cehrgpt.tools.linear_prob.compute_cehrgpt_features \
+    --data_folder \"$BASE_DIR/$cohort_name/train/\" \
+    --test_data_folder \"$BASE_DIR/$cohort_name/test/\" \
+    --dataset_prepared_path \"$DATASET_PREPARED_PATH\" \
+    --model_name_or_path \"$MODEL_PATH\" \
+    --tokenizer_name_or_path \"$MODEL_PATH\" \
+    --output_dir \"$output_dir\" \
+    --preprocessing_num_workers \"$PREPROCESSING_WORKERS\" \
+    --per_device_eval_batch_size \"$BATCH_SIZE\" \
+    --max_tokens_per_batch \"$MAX_TOKENS_PER_BATCH\" \
+    --torch_type \"$TORCH_TYPE\""
+
+  # Add sample packing flag if not disabled
+  if [ "$DISABLE_SAMPLE_PACKING" = "false" ]; then
+    FEATURE_CMD="$FEATURE_CMD --sample_packing"
+  fi
+
+  # Step 1: Feature extraction
+  log "Starting feature extraction for $cohort_name..."
+  log "Command: $FEATURE_CMD"
+
+  eval "$FEATURE_CMD > \"$cohort_log\" 2>&1"
+
+  feature_extraction_status=$?
+  if [ $feature_extraction_status -ne 0 ]; then
+    log "ERROR: Feature extraction failed for $cohort_name. Check $cohort_log for details."
+    continue
+  fi
+
+  # Step 2: Model training
+  log "Starting model training for $cohort_name..."
+  log "Command: python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features --features_data_dir $output_dir --output_dir $output_dir"
+
+  python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features \
+    --features_data_dir "$output_dir" \
+    --output_dir "$output_dir" \
+    >> "$cohort_log" 2>&1
+
+  model_training_status=$?
+  if [ $model_training_status -ne 0 ]; then
+    log "ERROR: Model training failed for $cohort_name. Check $cohort_log for details."
+    continue
+  fi
+
+  log "Running meds-evaluation for logistic regression for $cohort_name..."
+  meds-evaluation-cli predictions_path="$output_dir/logistic/test_predictions" \
+    output_dir="$output_dir/logistic/"
+
+  # Check if meds-evaluation succeeded
+  if [ $? -ne 0 ]; then
+    log "ERROR: meds-evaluation failed for logistic regression for $cohort_name"
+  fi
+
+  log "Successfully completed processing for $cohort_name"
+done < "$TEMP_COHORT_LIST"
+
+# Clean up
+rm -f "$TEMP_COHORT_LIST"
+
+log "All processing completed"
--- /dev/null
+++ cehrgpt-0.1.1/src/cehrgpt/analysis/irregularity.py
@@ -0,0 +1,36 @@
+import os
+
+import polars as pl
+
+from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
+
+
+def main(args):
+    dataset = pl.read_parquet(os.path.join(args.input_dir, "*.parquet"))
+    time_token_frequency_df = (
+        dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
+        .filter(pl.col("concept_id").map_elements(is_att_token))
+        .with_columns(
+            pl.col("concept_id")
+            .map_elements(extract_time_interval_in_days)
+            .alias("time_interval")
+        )
+    )
+    results = time_token_frequency_df.select(
+        pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
+    ).to_dicts()[0]
+    print(results)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="EHR Irregularity analysis")
+    parser.add_argument(
+        "--input_dir",
+        dest="input_dir",
+        action="store",
+        help="The path to the input data folder",
+        required=True,
+    )
+    main(parser.parse_args())
--- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset.py
+++ cehrgpt-0.1.1/src/cehrgpt/data/hf_cehrgpt_dataset.py
@@ -1,9 +1,10 @@
-from typing import Union
+from typing import Optional, Union
 
 from cehrbert.data_generators.hf_data_generator.hf_dataset import (
     FINETUNING_COLUMNS,
     apply_cehrbert_dataset_mapping,
 )
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 from datasets import Dataset, DatasetDict
 
@@ -22,6 +23,7 @@ CEHRGPT_COLUMNS = [
     "num_of_visits",
     "values",
     "value_indicators",
+    "epoch_times",
 ]
 
 TRANSFORMER_COLUMNS = ["input_ids"]
@@ -31,16 +33,25 @@ def create_cehrgpt_pretraining_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Union[Dataset, DatasetDict]:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
+    # TODO: temp solution; this column is mixed-typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset = dataset.remove_columns(["visit_concept_ids"])
     dataset = apply_cehrbert_dataset_mapping(
         dataset,
         HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
-
     if not data_args.streaming:
         if isinstance(dataset, DatasetDict):
             all_columns = dataset["train"].column_names
@@ -56,8 +67,17 @@ def create_cehrgpt_finetuning_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Union[Dataset, DatasetDict]:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
+    # TODO: temp solution; this column is mixed-typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset = dataset.remove_columns(["visit_concept_ids"])
     mapping_functions = [
         HFFineTuningMapping(cehrgpt_tokenizer),
     ]
@@ -68,6 +88,7 @@ def create_cehrgpt_finetuning_dataset(
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
 
     if not data_args.streaming: