genhpf-1.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (85)
  1. genhpf-1.0.0/.gitignore +159 -0
  2. genhpf-1.0.0/.pre-commit-config.yaml +45 -0
  3. genhpf-1.0.0/LICENSE +21 -0
  4. genhpf-1.0.0/PKG-INFO +197 -0
  5. genhpf-1.0.0/README.md +173 -0
  6. genhpf-1.0.0/examples/pretrain/mlm/genhpf/flattened_pt.yaml +44 -0
  7. genhpf-1.0.0/examples/pretrain/simclr/genhpf/genhpf_hierarchical_pt.yaml +43 -0
  8. genhpf-1.0.0/examples/pretrain/wav2vec2/genhpf/hierarchical_pt.yaml +44 -0
  9. genhpf-1.0.0/examples/test/genhpf/genhpf_flattened.yaml +87 -0
  10. genhpf-1.0.0/examples/test/genhpf/genhpf_hierarchical.yaml +86 -0
  11. genhpf-1.0.0/examples/test/genhpf/meds_hierarchical.yaml +55 -0
  12. genhpf-1.0.0/examples/train/genhpf/genhpf_flattened_ft.yaml +106 -0
  13. genhpf-1.0.0/examples/train/genhpf/genhpf_hierarchical_ft.yaml +105 -0
  14. genhpf-1.0.0/examples/train/genhpf/genhpf_hierarchical_scr.yaml +103 -0
  15. genhpf-1.0.0/examples/train/genhpf/meds_hierarchical_scr.yaml +50 -0
  16. genhpf-1.0.0/pyproject.toml +55 -0
  17. genhpf-1.0.0/requirements.txt +48 -0
  18. genhpf-1.0.0/setup.cfg +4 -0
  19. genhpf-1.0.0/src/genhpf/__init__.py +9 -0
  20. genhpf-1.0.0/src/genhpf/configs/__init__.py +23 -0
  21. genhpf-1.0.0/src/genhpf/configs/config.yaml +8 -0
  22. genhpf-1.0.0/src/genhpf/configs/configs.py +240 -0
  23. genhpf-1.0.0/src/genhpf/configs/constants.py +29 -0
  24. genhpf-1.0.0/src/genhpf/configs/initialize.py +58 -0
  25. genhpf-1.0.0/src/genhpf/configs/utils.py +29 -0
  26. genhpf-1.0.0/src/genhpf/criterions/__init__.py +74 -0
  27. genhpf-1.0.0/src/genhpf/criterions/binary_cross_entropy.py +114 -0
  28. genhpf-1.0.0/src/genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  29. genhpf-1.0.0/src/genhpf/criterions/criterion.py +87 -0
  30. genhpf-1.0.0/src/genhpf/criterions/cross_entropy.py +202 -0
  31. genhpf-1.0.0/src/genhpf/criterions/multi_task_criterion.py +177 -0
  32. genhpf-1.0.0/src/genhpf/criterions/simclr_criterion.py +84 -0
  33. genhpf-1.0.0/src/genhpf/criterions/wav2vec2_criterion.py +130 -0
  34. genhpf-1.0.0/src/genhpf/datasets/__init__.py +84 -0
  35. genhpf-1.0.0/src/genhpf/datasets/dataset.py +109 -0
  36. genhpf-1.0.0/src/genhpf/datasets/genhpf_dataset.py +451 -0
  37. genhpf-1.0.0/src/genhpf/datasets/meds_dataset.py +232 -0
  38. genhpf-1.0.0/src/genhpf/loggings/__init__.py +0 -0
  39. genhpf-1.0.0/src/genhpf/loggings/meters.py +374 -0
  40. genhpf-1.0.0/src/genhpf/loggings/metrics.py +155 -0
  41. genhpf-1.0.0/src/genhpf/loggings/progress_bar.py +445 -0
  42. genhpf-1.0.0/src/genhpf/models/__init__.py +73 -0
  43. genhpf-1.0.0/src/genhpf/models/genhpf.py +233 -0
  44. genhpf-1.0.0/src/genhpf/models/genhpf_mlm.py +64 -0
  45. genhpf-1.0.0/src/genhpf/models/genhpf_predictor.py +73 -0
  46. genhpf-1.0.0/src/genhpf/models/genhpf_simclr.py +58 -0
  47. genhpf-1.0.0/src/genhpf/models/genhpf_wav2vec2.py +304 -0
  48. genhpf-1.0.0/src/genhpf/modules/__init__.py +15 -0
  49. genhpf-1.0.0/src/genhpf/modules/gather_layer.py +23 -0
  50. genhpf-1.0.0/src/genhpf/modules/grad_multiply.py +12 -0
  51. genhpf-1.0.0/src/genhpf/modules/gumbel_vector_quantizer.py +204 -0
  52. genhpf-1.0.0/src/genhpf/modules/identity_layer.py +8 -0
  53. genhpf-1.0.0/src/genhpf/modules/layer_norm.py +27 -0
  54. genhpf-1.0.0/src/genhpf/modules/positional_encoding.py +24 -0
  55. genhpf-1.0.0/src/genhpf/scripts/__init__.py +0 -0
  56. genhpf-1.0.0/src/genhpf/scripts/preprocess/__init__.py +0 -0
  57. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/README.md +75 -0
  58. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  59. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  60. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  61. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  62. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  63. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  64. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/main.py +174 -0
  65. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  66. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  67. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  68. genhpf-1.0.0/src/genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  69. genhpf-1.0.0/src/genhpf/scripts/preprocess/manifest.py +83 -0
  70. genhpf-1.0.0/src/genhpf/scripts/preprocess/preprocess_meds.py +584 -0
  71. genhpf-1.0.0/src/genhpf/scripts/test.py +261 -0
  72. genhpf-1.0.0/src/genhpf/scripts/train.py +350 -0
  73. genhpf-1.0.0/src/genhpf/trainer.py +370 -0
  74. genhpf-1.0.0/src/genhpf/utils/checkpoint_utils.py +171 -0
  75. genhpf-1.0.0/src/genhpf/utils/data_utils.py +130 -0
  76. genhpf-1.0.0/src/genhpf/utils/distributed_utils.py +497 -0
  77. genhpf-1.0.0/src/genhpf/utils/file_io.py +170 -0
  78. genhpf-1.0.0/src/genhpf/utils/pdb.py +38 -0
  79. genhpf-1.0.0/src/genhpf/utils/utils.py +204 -0
  80. genhpf-1.0.0/src/genhpf.egg-info/PKG-INFO +197 -0
  81. genhpf-1.0.0/src/genhpf.egg-info/SOURCES.txt +83 -0
  82. genhpf-1.0.0/src/genhpf.egg-info/dependency_links.txt +1 -0
  83. genhpf-1.0.0/src/genhpf.egg-info/entry_points.txt +6 -0
  84. genhpf-1.0.0/src/genhpf.egg-info/requires.txt +12 -0
  85. genhpf-1.0.0/src/genhpf.egg-info/top_level.txt +1 -0
genhpf-1.0.0/.gitignore ADDED
@@ -0,0 +1,159 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+ *fatsta
+ .pkl
+ *.pkl
+ __pycache__/
+ checkpoints/
+ results/
+ .vscode/
+ wandb/
+ .wandb/
+ *.pt
+ *.pdf
+ .pdf
+ *.csv
+ .csv
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Experimental Folder
+ outputs/*
+ outputs
+
+ # Weights and Biases logs
+ wandb/
+
genhpf-1.0.0/.pre-commit-config.yaml ADDED
@@ -0,0 +1,45 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v5.0.0
+     hooks:
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: trailing-whitespace
+
+   # python code formatting
+   - repo: https://github.com/psf/black
+     rev: 23.7.0
+     hooks:
+       - id: black
+         args: [--line-length, "110"]
+
+   # python import sorting
+   - repo: https://github.com/PyCQA/isort
+     rev: 5.12.0
+     hooks:
+       - id: isort
+         args: ["--profile", "black", "--filter-files", "-o", "wandb"]
+
+   # python check (PEP8), programming errors and code complexity
+   - repo: https://github.com/PyCQA/flake8
+     rev: 6.1.0
+     hooks:
+       - id: flake8
+         args:
+           [
+             "--max-complexity=10",
+             "--extend-ignore",
+             "E402,E701,E251,E226,E302,W504,E704,E402,E401,C901,E203",
+             "--max-line-length=110",
+             "--exclude",
+             "logs/*,data/*",
+             "--per-file-ignores",
+             "__init__.py:F401",
+           ]
+
+   # yaml formatting
+   - repo: https://github.com/pre-commit/mirrors-prettier
+     rev: v3.0.3
+     hooks:
+       - id: prettier
+         types: [yaml]
genhpf-1.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 hoon9405
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
genhpf-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,197 @@
+ Metadata-Version: 2.2
+ Name: genhpf
+ Version: 1.0.0
+ Summary: GenHPF: General Healthcare Predictive Framework with Multi-task Multi-source Learning
+ Author-email: Jungwoo Oh <ojw0123@kaist.ac.kr>, Kyunghoon Hur <pacesun@kaist.ac.kr>
+ License: MIT license
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10.0
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: hydra-core==1.3.2
+ Requires-Dist: omegaconf==2.3.0
+ Requires-Dist: torch==2.6.0
+ Requires-Dist: transformers==4.49.0
+ Requires-Dist: h5pickle==0.4.2
+ Requires-Dist: scikit-learn==1.6.1
+ Requires-Dist: pandas==2.2.3
+ Requires-Dist: polars==1.17.1
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == "dev"
+ Requires-Dist: black; extra == "dev"
+
+ # GenHPF : General Healthcare Predictive Framework for Multi-task Multi-source Learning
+
+ GenHPF is a general healthcare predictive framework that requires no medical domain knowledge and only minimal preprocessing for multiple prediction tasks.
+
+ Our framework presents a method for embedding any form of EHR system for prediction tasks without requiring domain-knowledge-based preprocessing, such as medical code mapping and feature selection.
+
+ This repository provides the official PyTorch implementation of GenHPF.
+
+ # Getting started with GenHPF
+ ## STEP 1: Installation
+ For developing locally:
+ ```bash
+ $ pip install -e ./
+ ```
+
+ Otherwise:
+ ```bash
+ $ pip install genhpf
+ ```
+
+ ## STEP 2: Prepare training data
+ ### Preprocessing raw datasets to reproduce GenHPF paper results (GenHPF dataset)
+ Download the raw datasets and required tools:
+ * [MIMIC-III](https://physionet.org/content/mimiciii/1.4/)
+ * [MIMIC-IV](https://physionet.org/content/mimiciv/2.0/)
+ * [eICU](https://physionet.org/content/eicu-crd/2.0/)
+
+ Then, run:
+ ```bash
+ genhpf-preprocess \
+   --data $DATA_DIR \
+   --ehr {"eicu", "mimiciii", "mimiciv"} \
+   --dest $OUTPUT_DIR \
+   --first_icu \
+   --emb_type {"textbase", "codebase"} \
+   --feature {"all_features", "select"} \
+   --mortality \
+   --long_term_mortality \
+   ... # add desired prediction tasks
+ ```
+ This will output the processed data (`data.h5` and `label.csv`) into the `$DATA_DIR/data/` directory.
+ For detailed descriptions of each argument, see [src/genhpf/scripts/preprocess/genhpf/README.md](src/genhpf/scripts/preprocess/genhpf/README.md).
+ <!-- Note that pre-processing takes about 6 hours on 128 cores of an AMD EPYC 7502 32-core processor and requires 180 GB of RAM. -->
+
+ Finally, you should prepare the data manifest based on the preprocessed data:
+ ```bash
+ genhpf-manifest $data_dir $label_dir \
+   --dest=$output_dir \
+   --prefix=$prefix \
+   --valid_percent=$valid_percent
+ ```
+ This will generate the manifest files (e.g., `$prefix_train.tsv`, `$prefix_valid.tsv`, `$prefix_test.tsv`) in `$output_dir` based on `$data_dir`, which contains `data.h5`, and `$label_dir`, which contains `label.csv`.
+ The ratio among the train, valid, and test splits is determined by `$valid_percent`.
+ Note that this is useful for handling various combinations of training and test datasets.
+ For instance, if you want to use multiple datasets (e.g., mimiciv and eicu) for training and evaluate the model on each dataset separately, you can do so by placing the corresponding manifest files (e.g., mimiciv_train, eicu_train, mimiciv_valid, eicu_valid, mimiciv_test, eicu_test) in the same data directory and specifying the following command-line arguments: `dataset.train_subset="mimiciv_train,eicu_train" dataset.combine_train_subsets=true dataset.valid_subset="mimiciv_valid,eicu_valid" dataset.test_subset="mimiciv_test,eicu_test"`.
+
+ ### Preprocessing MEDS dataset
+ We also provide a script to preprocess a [MEDS](https://github.com/mmcdermott/MEDS-DEV) dataset with a cohort defined by [ACES](https://github.com/justin13601/ACES) or [MEDS-DEV](https://github.com/mmcdermott/MEDS-DEV) (see the Task section) so that it can run with GenHPF.
+
+ ```bash
+ genhpf-preprocess-meds $MEDS_DATA_DIR \
+   --cohort $MEDS_LABELS_DIR \
+   --metadata_dir $MEDS_METADATA_DIR \
+   --output_dir $MEDS_OUTPUT_DIR \
+   --workers $NUM_WORKERS
+ ```
+
+ * `$MEDS_DATA_DIR`: a path to the data directory containing the MEDS data to be processed. It can be a directory or an exact file path with the file extension (only `.csv` or `.parquet` allowed). If given a directory, the script recursively scans it for all `*.csv` or `*.parquet` files. See [this](https://github.com/mmcdermott/MEDS-DEV?tab=readme-ov-file#building-a-dataset) if you want to build a new MEDS dataset based on MIMIC-III, MIMIC-IV, or eICU.
+ * `$MEDS_LABELS_DIR`: a path to the label directory for a given task, which must be a result of [ACES](https://github.com/justin13601/ACES) or [MEDS-DEV](https://github.com/mmcdermott/MEDS-DEV). It can be a directory or an exact file path with the same file extension as the MEDS dataset to be processed. The file structure of this cohort directory should mirror the provided MEDS data directory (`$MEDS_DATA_DIR`) so that each cohort can be matched to its corresponding shard data. See [this](https://github.com/mmcdermott/MEDS-DEV?tab=readme-ov-file#extracting-a-task) to extract a cohort for a specific task defined in MEDS-DEV.
+ * `$MEDS_METADATA_DIR`: a path to the metadata directory for the input MEDS dataset, expected to contain `codes.parquet`. This is used to retrieve descriptions for codes in MEDS events and convert each code to the retrieved description. Note that if a code has no description in `codes.parquet`, the code is treated as plain text and the event is processed as-is.
+ * `$MEDS_OUTPUT_DIR`: the directory in which to save the processed outputs.
+   * Enabling `--rebase` will recreate this directory.
+ * `$NUM_WORKERS`: the number of parallel workers used to multi-process the script.
+   * **NOTE: if you encounter the error _"polars' maximum length reached. consider installing 'polars-u64-idx'"_, consider using more workers or installing polars-u64-idx via `pip install polars-u64-idx`.**
+
+ As a result, you will have `.h5` and `.tsv` files with the following respective structures:
+ * `*.h5`
+   ```
+   *.h5
+   └── ${cohort_id}
+       └── "ehr"
+           ├── "hi"
+           │   └── np.ndarray with a shape of (num_events, 3, max_length)
+           ├── "time"
+           │   └── np.ndarray with a shape of (num_events, )
+           └── "label"
+               └── binary label (0 or 1) for ${cohort_id} given the defined task
+   ```
+   * `${cohort_id}`: `${patient_id}_${cohort_number}`, identifying the **N-th cohort for the patient**.
+   * Numpy array under `"hi"`
+     * `[:, 0, :]`: token input ids (i.e., `input_ids`) for the tokenized events.
+     * `[:, 1, :]`: token type ids (i.e., `type_ids`) to distinguish where each input token comes from (special tokens such as `[CLS]` or `[SEP]`, column keys, or column values).
+     * `[:, 2, :]`: tokens indicating digit places for number-type tokens (i.e., `dpe_ids`). It assigns a different id to each digit place of numeric (integer or float) items.
+   * Numpy array under `"time"`
+     * Elapsed time in minutes from the first event to each event. We do not use this feature currently, but reserve it for future use (e.g., developing a method to embed events with their temporal features).
+ * `*.tsv`
+   ```
+       patient_id   num_events
+   0   10001472_0   13
+   1   10002013_0   47
+   2   10002013_1   46
+   …   …            …
+   ```
+
+ ## STEP 3: Training a new model
+ We provide example configuration files for various models and experimental setups.
+ For detailed configurations, please see [configs.py](src/genhpf/configs/configs.py) and the corresponding source files (e.g., [genhpf.py](src/genhpf/models/genhpf.py)).
+
+ ### Examples with the GenHPF dataset
+ ### Train a new GenHPF model from scratch:
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   --config-dir ${GENHPF_DIR}/examples/train/genhpf \
+   --config-name genhpf_hierarchical_scr
+ ```
+ Note that you should fill in `dataset.data=???` with a path to the directory that contains the data manifest files (e.g., `train.tsv`, `valid.tsv`, etc.) for the processed GenHPF data.
+
+ ### Pre-train and fine-tune a new GenHPF model:
+ For pre-training with SimCLR:
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   --config-dir ${GENHPF_DIR}/examples/pretrain/simclr/genhpf \
+   --config-name genhpf_hierarchical_pt
+ ```
+ For fine-tuning:
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   model.from_pretrained=${/path/to/the/pretrained/checkpoint.pt} \
+   --config-dir ${GENHPF_DIR}/examples/train/genhpf \
+   --config-name genhpf_hierarchical_ft
+ ```
+
+ ### Examples with the MEDS dataset
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   --config-dir ${GENHPF_DIR}/examples/train/genhpf \
+   --config-name meds_hierarchical_scr
+ ```
+ Note that you should fill in `dataset.data=???` with a path to the directory that contains the data manifest files (e.g., `train.tsv`, `tuning.tsv`, etc.) for the processed MEDS data (i.e., `$MEDS_OUTPUT_DIR`).
+
+ To run inference on the MEDS dataset and output prediction results for evaluation with [meds-evaluation](https://github.com/kamilest/meds-evaluation):
+ ```bash
+ genhpf-test \
+   dataset.data=??? \
+   meds.output_predictions=true \
+   meds.labels_dir=$MEDS_LABELS_DIR \
+   meds.output_dir=$OUTPUT_DIR \
+   checkpoint.load_checkpoint=${/path/to/the/trained/checkpoint.pt} \
+   --config-dir ${GENHPF_DIR}/examples/test/genhpf \
+   --config-name meds_hierarchical
+ ```
+ This script will load the model weights from `${/path/to/the/trained/checkpoint.pt}`, process the data specified by `dataset.data`, and output the prediction results for the test subset as a single parquet file in the `$OUTPUT_DIR` directory.
+ Note that the data directory `dataset.data` should contain the directory for the test data with its manifest file (e.g., `held_out/*.h5` with `held_out.tsv`), where the name of the test subset is specified by the `dataset.test_subset` config.
+
+ ## Citation
+ If you find GenHPF useful for your research and applications, please cite using this BibTeX:
+ ```bibtex
+
+ @article{hur2023genhpf,
+   title={GenHPF: General Healthcare Predictive Framework for Multi-task Multi-source Learning},
+   author={Hur, Kyunghoon and Oh, Jungwoo and Kim, Junu and Kim, Jiyoun and Lee, Min Jae and Cho, Eunbyeol and Moon, Seong-Eun and Kim, Young-Hak and Atallah, Louis and Choi, Edward},
+   journal={IEEE Journal of Biomedical and Health Informatics},
+   year={2023},
+   publisher={IEEE}
+ }
+ ```
+
+ # License
+ This repository is MIT-licensed.
genhpf-1.0.0/README.md ADDED
@@ -0,0 +1,173 @@
1
+ # GenHPF : General Healthcare Predictive Framework for Multi-task Multi-source Learning
2
+
3
+ GenHPF is a general healthcare predictive framework, which requires no medical domain knowledge and minimal preprocessing for multiple prediction tasks.
4
+
5
+ Our framework presents a method for embedding any form of EHR systems for prediction tasks without requiring domain-knowledge-based pre-processing, such as medical code mapping and feature selection.
6
+
7
+ This repository provides official Pytorch code to implement GenHPF, a general healthcare predictive framework.
8
+
9
+ # Getting started with GenHPF
10
+ ## STEP 1: Installation
11
+ For developing locally:
12
+ ```bash
13
+ $ pip install -e ./
14
+ ```
15
+
16
+ Otherwise:
17
+ ```bash
18
+ $ pip install genhpf
19
+ ```
20
+
21
+ ## STEP 2: Prepare training data
22
+ ### Preprocessing raw datasets to reproduce GenHPF paper results (GenHPF dataset)
23
+ Download raw datasets and required tools:
24
+ * [MIMIC-III](https://physionet.org/content/mimiciii/1.4/)
25
+ * [MIMIC-IV](https://physionet.org/content/mimiciv/2.0/)
26
+ * [eICU](https://physionet.org/content/eicu-crd/2.0/)
27
+
28
+ Then, run:
29
+ ```bash
30
+ genhpf-preprocess \
31
+ --data $DATA_DIR \
32
+ --ehr {"eicu", "mimiciii", "mimiciv"} \
33
+ --dest $OUTPUT_DIR \
34
+ --first_icu \
35
+ --emb_type {"textbase", "codebase"} \
36
+ --feature {"all_features", "select"} \
37
+ --mortality \
38
+ --long_term_mortality \
39
+ ... # add desired prediction tasks
40
+ ```
41
+ This will output the processed data (`data.h5` and `label.csv`) into `$DATA_DIR/data/` directory.
42
+ For detailed descriptions for each argument, see [src/genhpf/scripts/preprocess/genhpf/README.md](src/genhpf/scripts/preprocess/genhpf/README.md).
43
+ <!-- Note that pre-processing takes about 6hours in 128 cores of AMD EPYC 7502 32-Core Processor, and requires 180GB of RAM. -->
44
+
45
+ Finally, you should prepare data manifest based on the preprocessed data:
46
+ ```bash
47
+ genhpf-manifest $data_dir $label_dir \
48
+ --dest=$output_dir \
49
+ --prefix=$prefix \
50
+ --valid_percent=$valid_percent
51
+ ```
52
+ This will generate the manifest files (e.g., `$prefix_train.tsv`, `$prefix_valid.tsv`, `$prefix_test.tsv`) to `$output_dir` based on the `$data_dir`, which contains `data.h5`, and `$label_dir`, which contains `label.csv`.
53
+ The ratio among train, valid, and test splits is decided by `$valid_percent`.
54
+ Note that this is useful to handle various concepts of training and test datasets.
55
+ For instance, if you want to use multiple datasets (e.g., mimiciv and eicu) for training and evaluate the model on each of the datasets separately, you can perform it by placing the corresponding manifest files (e.g., mimiciv_train, eicu_train, mimiciv_valid, eicu_valid, mimiciv_test, eicu_test) in the same data directory and specifying the following command-line arguments: `dataset.train_subset="mimiciv_train,eicu_train" dataset.combine_train_subsets=true dataset.valid_subset="mimiciv_valid,eicu_valid" dataset.test_subset="mimiciv_test,eicu_test"`.
56
+
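Below is a minimal sketch (not part of the package) for sanity-checking the split sizes that `genhpf-manifest` produces. It assumes the manifests are plain TSV files, as in the MEDS manifest example shown later in this README; the paths and column layout are hypothetical, so adjust them to your files.

```python
# Hypothetical sketch: count samples per split and report the resulting ratio.
import pandas as pd

splits = {}
for split in ["train", "valid", "test"]:
    # Assumed path layout: $output_dir/$prefix_$split.tsv from genhpf-manifest.
    df = pd.read_csv(f"manifests/mimiciv_{split}.tsv", sep="\t")
    splits[split] = len(df)

total = sum(splits.values())
for split, n in splits.items():
    print(f"{split}: {n} samples ({n / total:.1%})")
```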
+ ### Preprocessing MEDS dataset
+ We also provide a script to preprocess a [MEDS](https://github.com/mmcdermott/MEDS-DEV) dataset with a cohort defined by [ACES](https://github.com/justin13601/ACES) or [MEDS-DEV](https://github.com/mmcdermott/MEDS-DEV) (see the Task section) so that it can run with GenHPF.
+
+ ```bash
+ genhpf-preprocess-meds $MEDS_DATA_DIR \
+   --cohort $MEDS_LABELS_DIR \
+   --metadata_dir $MEDS_METADATA_DIR \
+   --output_dir $MEDS_OUTPUT_DIR \
+   --workers $NUM_WORKERS
+ ```
+
+ * `$MEDS_DATA_DIR`: a path to the data directory containing the MEDS data to be processed. It can be a directory or an exact file path with the file extension (only `.csv` or `.parquet` allowed). If given a directory, the script recursively scans it for all `*.csv` or `*.parquet` files. See [this](https://github.com/mmcdermott/MEDS-DEV?tab=readme-ov-file#building-a-dataset) if you want to build a new MEDS dataset based on MIMIC-III, MIMIC-IV, or eICU.
+ * `$MEDS_LABELS_DIR`: a path to the label directory for a given task, which must be a result of [ACES](https://github.com/justin13601/ACES) or [MEDS-DEV](https://github.com/mmcdermott/MEDS-DEV). It can be a directory or an exact file path with the same file extension as the MEDS dataset to be processed. The file structure of this cohort directory should mirror the provided MEDS data directory (`$MEDS_DATA_DIR`) so that each cohort can be matched to its corresponding shard data. See [this](https://github.com/mmcdermott/MEDS-DEV?tab=readme-ov-file#extracting-a-task) to extract a cohort for a specific task defined in MEDS-DEV.
+ * `$MEDS_METADATA_DIR`: a path to the metadata directory for the input MEDS dataset, expected to contain `codes.parquet`. This is used to retrieve descriptions for codes in MEDS events and convert each code to the retrieved description. Note that if a code has no description in `codes.parquet`, the code is treated as plain text and the event is processed as-is.
+ * `$MEDS_OUTPUT_DIR`: the directory in which to save the processed outputs.
+   * Enabling `--rebase` will recreate this directory.
+ * `$NUM_WORKERS`: the number of parallel workers used to multi-process the script.
+   * **NOTE: if you encounter the error _"polars' maximum length reached. consider installing 'polars-u64-idx'"_, consider using more workers or installing polars-u64-idx via `pip install polars-u64-idx`.**
+
+ As a result, you will have `.h5` and `.tsv` files with the following respective structures (a short loading sketch follows this list):
+ * `*.h5`
+   ```
+   *.h5
+   └── ${cohort_id}
+       └── "ehr"
+           ├── "hi"
+           │   └── np.ndarray with a shape of (num_events, 3, max_length)
+           ├── "time"
+           │   └── np.ndarray with a shape of (num_events, )
+           └── "label"
+               └── binary label (0 or 1) for ${cohort_id} given the defined task
+   ```
+   * `${cohort_id}`: `${patient_id}_${cohort_number}`, identifying the **N-th cohort for the patient**.
+   * Numpy array under `"hi"`
+     * `[:, 0, :]`: token input ids (i.e., `input_ids`) for the tokenized events.
+     * `[:, 1, :]`: token type ids (i.e., `type_ids`) to distinguish where each input token comes from (special tokens such as `[CLS]` or `[SEP]`, column keys, or column values).
+     * `[:, 2, :]`: tokens indicating digit places for number-type tokens (i.e., `dpe_ids`). It assigns a different id to each digit place of numeric (integer or float) items.
+   * Numpy array under `"time"`
+     * Elapsed time in minutes from the first event to each event. We do not use this feature currently, but reserve it for future use (e.g., developing a method to embed events with their temporal features).
+ * `*.tsv`
+   ```
+       patient_id   num_events
+   0   10001472_0   13
+   1   10002013_0   47
+   2   10002013_1   46
+   …   …            …
+   ```
+
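To make the layout above concrete, here is a minimal sketch (not part of the package) that walks the documented `.h5`/`.tsv` structure with `h5py` and `pandas` (genhpf itself depends on `h5pickle`, an `h5py` wrapper). The file names `train.h5` and `train.tsv` are assumptions.

```python
# Sketch: inspect a few cohorts of a processed MEDS shard, following the
# documented layout: f[cohort_id]["ehr"] holds "hi", "time", and "label".
import h5py
import pandas as pd

manifest = pd.read_csv("train.tsv", sep="\t")  # columns: patient_id, num_events
print(manifest.head())

with h5py.File("train.h5", "r") as f:
    for cohort_id in list(f.keys())[:3]:
        ehr = f[cohort_id]["ehr"]
        hi = ehr["hi"][:]           # (num_events, 3, max_length)
        time = ehr["time"][:]       # (num_events,)
        label = ehr["label"][()]    # binary label, 0 or 1
        # The middle axis separates the three token streams described above.
        input_ids, type_ids, dpe_ids = hi[:, 0], hi[:, 1], hi[:, 2]
        print(cohort_id, hi.shape, time.shape, int(label))
```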
+ ## STEP 3: Training a new model
+ We provide example configuration files for various models and experimental setups.
+ For detailed configurations, please see [configs.py](src/genhpf/configs/configs.py) and the corresponding source files (e.g., [genhpf.py](src/genhpf/models/genhpf.py)).
+
+ ### Examples with the GenHPF dataset
+ ### Train a new GenHPF model from scratch:
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   --config-dir ${GENHPF_DIR}/examples/train/genhpf \
+   --config-name genhpf_hierarchical_scr
+ ```
+ Note that you should fill in `dataset.data=???` with a path to the directory that contains the data manifest files (e.g., `train.tsv`, `valid.tsv`, etc.) for the processed GenHPF data.
+
+ ### Pre-train and fine-tune a new GenHPF model:
+ For pre-training with SimCLR:
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   --config-dir ${GENHPF_DIR}/examples/pretrain/simclr/genhpf \
+   --config-name genhpf_hierarchical_pt
+ ```
+ For fine-tuning:
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   model.from_pretrained=${/path/to/the/pretrained/checkpoint.pt} \
+   --config-dir ${GENHPF_DIR}/examples/train/genhpf \
+   --config-name genhpf_hierarchical_ft
+ ```
+
+ ### Examples with the MEDS dataset
+ ```bash
+ genhpf-train \
+   dataset.data=??? \
+   --config-dir ${GENHPF_DIR}/examples/train/genhpf \
+   --config-name meds_hierarchical_scr
+ ```
+ Note that you should fill in `dataset.data=???` with a path to the directory that contains the data manifest files (e.g., `train.tsv`, `tuning.tsv`, etc.) for the processed MEDS data (i.e., `$MEDS_OUTPUT_DIR`).
+
+ To run inference on the MEDS dataset and output prediction results for evaluation with [meds-evaluation](https://github.com/kamilest/meds-evaluation):
+ ```bash
+ genhpf-test \
+   dataset.data=??? \
+   meds.output_predictions=true \
+   meds.labels_dir=$MEDS_LABELS_DIR \
+   meds.output_dir=$OUTPUT_DIR \
+   checkpoint.load_checkpoint=${/path/to/the/trained/checkpoint.pt} \
+   --config-dir ${GENHPF_DIR}/examples/test/genhpf \
+   --config-name meds_hierarchical
+ ```
+ This script will load the model weights from `${/path/to/the/trained/checkpoint.pt}`, process the data specified by `dataset.data`, and output the prediction results for the test subset as a single parquet file in the `$OUTPUT_DIR` directory.
+ Note that the data directory `dataset.data` should contain the directory for the test data with its manifest file (e.g., `held_out/*.h5` with `held_out.tsv`), where the name of the test subset is specified by the `dataset.test_subset` config.
+
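As a quick check before running [meds-evaluation](https://github.com/kamilest/meds-evaluation), the emitted parquet can be loaded with `polars` (a pinned dependency of genhpf). This is a minimal sketch; the glob path is hypothetical, and the exact column schema is whatever `genhpf-test` writes for meds-evaluation (typically subject ids, labels, and predicted probabilities).

```python
# Sketch: peek at the prediction parquet written by genhpf-test.
import polars as pl

preds = pl.read_parquet("outputs/*.parquet")  # hypothetical output location
print(preds.schema)
print(preds.head())
```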
+ ## Citation
+ If you find GenHPF useful for your research and applications, please cite using this BibTeX:
+ ```bibtex
+
+ @article{hur2023genhpf,
+   title={GenHPF: General Healthcare Predictive Framework for Multi-task Multi-source Learning},
+   author={Hur, Kyunghoon and Oh, Jungwoo and Kim, Junu and Kim, Jiyoun and Lee, Min Jae and Cho, Eunbyeol and Moon, Seong-Eun and Kim, Young-Hak and Atallah, Louis and Choi, Edward},
+   journal={IEEE Journal of Biomedical and Health Informatics},
+   year={2023},
+   publisher={IEEE}
+ }
+ ```
+
+ # License
+ This repository is MIT-licensed.
genhpf-1.0.0/examples/pretrain/mlm/genhpf/flattened_pt.yaml ADDED
@@ -0,0 +1,44 @@
+ common:
+   log_format: tqdm
+   log_interval: 10
+   all_gather_list_size: 2048000
+   seed: 42
+
+ checkpoint:
+   save_dir: checkpoints
+   save_interval: 1
+   keep_last_epochs: 5
+
+ dataset:
+   data_format: genhpf
+   data: ???
+   label: false
+   vocab_size: 28996
+   pad_token_id: 0
+   sep_token_id: 102
+   dummy_token_id: 101
+   ignore_index: -100
+
+   num_workers: 6
+   batch_size: 4
+   train_subset: "train"
+   valid_subset: "valid"
+   test_subset: "test"
+
+ distributed_training:
+   distributed_world_size: 1
+   find_unused_parameters: False
+
+ criterion:
+   _name: cross_entropy
+
+ optimization:
+   max_epoch: 200
+   lr: 1e-4
+
+ model:
+   _name: genhpf_mlm
+
+   structure: flattened
+   embedding_method: text
+   agg_max_seq_len: 8192
genhpf-1.0.0/examples/pretrain/simclr/genhpf/genhpf_hierarchical_pt.yaml ADDED
@@ -0,0 +1,43 @@
+ common:
+   log_format: tqdm
+   log_interval: 10
+   all_gather_list_size: 2048000
+   seed: 42
+
+ checkpoint:
+   save_dir: checkpoints
+   save_interval: 1
+   keep_last_epochs: 5
+
+ dataset:
+   data_format: genhpf
+   data: ???
+   label: false
+   vocab_size: 28996
+   pad_token_id: 0
+   sep_token_id: 102
+   dummy_token_id: 101
+   ignore_index: -100
+
+   num_workers: 6
+   batch_size: 8
+   train_subset: "train"
+   valid_subset: "valid"
+   test_subset: "test"
+
+ distributed_training:
+   distributed_world_size: 1
+   find_unused_parameters: False
+
+ criterion:
+   _name: simclr_criterion
+
+ optimization:
+   max_epoch: 200
+   lr: 1e-4
+
+ model:
+   _name: genhpf_simclr
+
+   structure: hierarchical
+   embedding_method: text
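The example configs above (and the other YAML files under `examples/`) are plain Hydra/OmegaConf YAML, so they can be loaded and inspected programmatically. A minimal sketch, assuming you run it from the repository root (`omegaconf` is a pinned dependency of genhpf):

```python
# Sketch: load an example config and read a few fields.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/pretrain/simclr/genhpf/genhpf_hierarchical_pt.yaml")
print(cfg.model["_name"], cfg.criterion["_name"])  # genhpf_simclr, simclr_criterion
print(cfg.dataset.batch_size, cfg.optimization.max_epoch)
```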