genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/utils/utils.py ADDED
@@ -0,0 +1,204 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
def buffered_arange(max):
    """Return the LongTensor ``[0, 1, ..., max - 1]`` sliced from a cached buffer.

    The buffer is created once with ``torch.LongTensor()`` and stored on the
    function object as ``buffered_arange.buf``. It is resized and refilled
    only when a larger ``max`` than ever before is requested; otherwise the
    leading ``max`` elements of the existing buffer are returned. Because the
    result is a slice of the shared buffer, callers should not modify it in
    place.
    """
    cache = getattr(buffered_arange, "buf", None)
    if cache is None:
        cache = buffered_arange.buf = torch.LongTensor()
    # Grow (and refill) the cache only when it is too small for this request.
    if cache.numel() < max:
        cache.resize_(max)
        torch.arange(max, out=cache)
    return cache[:max]
11
+
12
def sample_word(logits, strategy='topp', **kwargs):
    """Select next-token indices from ``logits`` with the named strategy.

    Args:
        logits: tensor of unnormalised scores passed through to the strategy.
        strategy: only ``'topp'`` is supported; it dispatches to
            ``sample_top_p`` (which performs top-k sampling or greedy
            arg-max selection).
        **kwargs: forwarded unchanged to the strategy implementation.

    Returns:
        Whatever the selected strategy returns.

    Raises:
        NotImplementedError: if ``strategy`` is not ``'topp'``. The message
            names the rejected strategy so misconfiguration is easy to spot.
    """
    if strategy == 'topp':
        return sample_top_p(logits, **kwargs)
    raise NotImplementedError(
        f"Unsupported sampling strategy: {strategy!r} (only 'topp' is implemented)"
    )
17
+
18
def sample_top_p(logits, k=5, sample=False, **kwargs):
    """Pick next-token indices from ``logits``.

    Despite its name, this performs **top-k** selection, not nucleus (top-p)
    sampling:

    * ``sample=False`` (default): greedy selection. Returns the arg-max index
      along the last dimension with shape ``(..., 1)``; ``k`` is ignored.
    * ``sample=True``: keeps the logits that are >= the k-th largest one,
      renormalises them with a softmax, and draws one index per row with
      ``torch.multinomial``.

    Args:
        logits: tensor of unnormalised scores; the last dimension indexes the
            candidates.
        k: number of candidates kept when sampling. Values larger than the
            last dimension are clamped to it, instead of making
            ``torch.topk`` raise.
        sample: sample from the top-k distribution (True) or take the
            arg-max (False).
        **kwargs: ignored; accepted so ``sample_word`` can forward extra
            options.

    Returns:
        LongTensor of selected indices.
        - ``sample=False``: shape ``(..., 1)``.
        - ``sample=True``: the probabilities are passed through ``squeeze(0)``
          before ``torch.multinomial``. A ``(V,)`` or ``(1, V)`` input
          therefore yields shape ``(1,)``, and a ``(B, V)`` input with
          ``B > 1`` yields ``(B, 1)``.
    """
    if sample:
        # torch.topk raises if k exceeds the size of the last dimension.
        k = min(k, logits.size(-1))
        top_values, _ = torch.topk(logits, k)
        # Mask everything strictly below the k-th largest logit.
        masked = logits.clone()
        masked[masked < top_values[..., [-1]]] = -float("Inf")
        probs = F.softmax(masked, dim=-1)
        # squeeze(0) drops a leading batch dim of size 1; torch.multinomial
        # only accepts 1-D or 2-D probability tensors.
        next_words = torch.multinomial(probs.squeeze(0), num_samples=1)
    else:
        _, next_words = torch.topk(logits, k=1, dim=-1)

    return next_words
32
+
33
def item(tensor):
    """Reduce ``tensor`` to a single Python value.

    Resolution order:
    1. If the object has an ``item`` attribute, call it (e.g. torch tensors).
    2. Else, if the object supports indexing, return its element ``[0]``.
    3. Otherwise, return the object unchanged.

    NOTE: this module defines ``item`` a second time further down; that later
    definition replaces this one when the module is imported.
    """
    # hasattr() only swallows AttributeError, so this try/except mirrors it.
    try:
        to_scalar = tensor.item
    except AttributeError:
        pass
    else:
        return to_scalar()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor
39
+
40
def apply_to_sample(f, sample):
    """Apply ``f`` to every tensor inside a (possibly nested) sample.

    Dicts, lists, tuples and sets are traversed recursively and rebuilt with
    the same container type. Tensors are replaced by ``f(tensor)``, and any
    other leaf is returned unchanged.

    Args:
        f: callable applied to each tensor.
        sample: a tensor or a nested container of tensors and other values.

    Returns:
        The transformed structure. An empty sized ``sample`` (e.g. ``[]``,
        ``{}``) returns ``{}``. A 0-d tensor is passed directly to ``f``,
        because ``len()`` raises ``TypeError`` on 0-d tensors.

    NOTE: ``apply_to_sample``, ``move_to_cuda`` and ``move_to_cpu`` are
    defined a second time further down in this module; those later
    definitions replace these ones when the module is imported.
    """
    # A 0-d tensor defines __len__, but len() raises TypeError on it.
    if torch.is_tensor(sample) and sample.dim() == 0:
        return f(sample)
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(elem) for elem in x]
        elif isinstance(x, tuple):
            return tuple(_apply(elem) for elem in x)
        elif isinstance(x, set):
            return {_apply(elem) for elem in x}
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample, device=None):
    """Move every tensor in ``sample`` to ``device``.

    Args:
        sample: a tensor or nested container (see ``apply_to_sample``).
        device: target device. ``None`` means ``torch.cuda.current_device()``.
            An explicit ``0`` is honoured; the former ``device or ...`` check
            treated it as "unset".
    """
    if device is None:
        device = torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored if tensor is not pinned, so we can always set
        # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620)
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)


def move_to_cpu(sample):
    """Move every tensor in ``sample`` to CPU, upcasting float16/bfloat16 to float32."""

    def _move_to_cpu(tensor):
        # PyTorch has poor support for half tensors (float16) on CPU.
        # Move any such tensors to float32
        if tensor.dtype in {torch.bfloat16, torch.float16}:
            tensor = tensor.to(dtype=torch.float32)
        return tensor.cpu()

    return apply_to_sample(_move_to_cpu, sample)
78
+
79
+
80
+ import logging
81
+ from contextlib import contextmanager
82
+
83
+ import torch
84
+ import torch.nn.functional as F
85
+
86
# Module-level logger named after this module (used by
# CudaEnvironment.pretty_print_cuda_env_list).
logger = logging.getLogger(name=__name__)
87
+
88
def item(tensor):
    """Reduce ``tensor`` to a single Python value.

    Resolution order:
    1. If the object has an ``item`` attribute, call it (e.g. torch tensors).
    2. Else, if the object supports indexing, return its element ``[0]``.
    3. Otherwise, return the object unchanged.
    """
    # hasattr() only swallows AttributeError, so this try/except mirrors it.
    try:
        to_scalar = tensor.item
    except AttributeError:
        pass
    else:
        return to_scalar()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor
94
+
95
def get_rng_state():
    """Snapshot the global torch RNG state.

    Returns:
        dict with ``"torch_rng_state"`` (CPU generator state) and, when CUDA
        is available, ``"cuda_rng_state"`` (current CUDA device's generator
        state).
    """
    state = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state):
    """Restore an RNG snapshot produced by ``get_rng_state``.

    The CUDA state is restored only if CUDA is available AND the snapshot
    contains one. A snapshot taken on a CPU-only machine can therefore be
    restored on a CUDA machine without raising ``KeyError``.
    """
    torch.set_rng_state(state["torch_rng_state"])
    if torch.cuda.is_available() and "cuda_rng_state" in state:
        torch.cuda.set_rng_state(state["cuda_rng_state"])


class set_torch_seed(object):
    """Seed torch's RNGs on construction; restore the previous state on exit.

    Intended use::

        with set_torch_seed(seed):
            ...  # reproducible random ops

    Seeding happens in ``__init__``. ``__exit__`` restores the RNG state and
    the ``torch.backends.cudnn.deterministic`` flag captured at construction,
    so neither leaks past the ``with`` block.
    """

    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = get_rng_state()
        # Remember the flag so __exit__ can undo the change made below.
        self._prev_cudnn_deterministic = torch.backends.cudnn.deterministic

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        set_rng_state(self.rng_state)
        torch.backends.cudnn.deterministic = self._prev_cudnn_deterministic
121
+
122
def apply_to_sample(f, sample):
    """Apply ``f`` to every tensor inside a (possibly nested) sample.

    Dicts, lists, tuples and sets are traversed recursively and rebuilt with
    the same container type. Tensors are replaced by ``f(tensor)``, and any
    other leaf is returned unchanged.

    Args:
        f: callable applied to each tensor.
        sample: a tensor or a nested container of tensors and other values.

    Returns:
        The transformed structure. An empty sized ``sample`` (e.g. ``[]``,
        ``{}``) returns ``{}``. A 0-d tensor is passed directly to ``f``,
        because ``len()`` raises ``TypeError`` on 0-d tensors.
    """
    # A 0-d tensor defines __len__, but len() raises TypeError on it.
    if torch.is_tensor(sample) and sample.dim() == 0:
        return f(sample)
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(elem) for elem in x]
        elif isinstance(x, tuple):
            return tuple(_apply(elem) for elem in x)
        elif isinstance(x, set):
            return {_apply(elem) for elem in x}
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample, device=None):
    """Move every tensor in ``sample`` to ``device``.

    Args:
        sample: a tensor or nested container (see ``apply_to_sample``).
        device: target device. ``None`` means ``torch.cuda.current_device()``.
            An explicit ``0`` is honoured; the former ``device or ...`` check
            treated it as "unset".
    """
    if device is None:
        device = torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored for non-pinned tensors, so it is always
        # safe to request it.
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)


def move_to_cpu(sample):
    """Move every tensor in ``sample`` to CPU, upcasting float16/bfloat16 to float32."""

    def _move_to_cpu(tensor):
        # Half-precision tensors have limited CPU op support; upcast them.
        if tensor.dtype in {torch.bfloat16, torch.float16}:
            tensor = tensor.to(dtype=torch.float32)
        return tensor.cpu()

    return apply_to_sample(_move_to_cpu, sample)


def prepare_sample(sample):
    """Move ``sample`` to the current CUDA device when CUDA is available; otherwise return it unchanged."""
    if torch.cuda.is_available():
        sample = move_to_cuda(sample)

    return sample
162
+
163
class CudaEnvironment(object):
    """Snapshot of the current CUDA device's name, compute capability and total memory."""

    def __init__(self):
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(f"cuda:{device_index}")
        self.name = props.name
        self.major = props.major
        self.minor = props.minor
        # total_memory is reported in bytes; convert to GiB.
        self.total_memory_in_GB = props.total_memory / 1024 / 1024 / 1024

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Log one line per CudaEnvironment object in ``cuda_env_list``
        (indexed by rank), framed by a starred banner line before and after.
        """
        title = "CUDA enviroments for all {} workers".format(len(cuda_env_list))
        padding = "*" * (40 - len(title) // 2)
        banner = padding + title + padding
        logger.info(banner)
        for rank, env in enumerate(cuda_env_list):
            logger.info(
                f"rank {rank:3d}: "
                f"capabilities = {env.major:2d}.{env.minor:<2d} ; "
                f"total memory = {env.total_memory_in_GB:.3f} GB ; "
                f"name = {env.name:40s}"
            )
        logger.info(banner)
190
+
191
def has_parameters(module):
    """Return True if ``module.parameters()`` yields at least one element.

    Only the first element is pulled from the iterator.
    """
    sentinel = object()
    return next(module.parameters(), sentinel) is not sentinel
197
+
198
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily set ``logger.name`` to ``new_name`` inside a ``with`` block.

    Args:
        logger: object with a mutable ``name`` attribute (e.g. a
            ``logging.Logger``).
        new_name: temporary name; ``None`` leaves the name unchanged.

    Yields:
        The same ``logger`` object.

    The original name is restored when the block exits, including when the
    block raises.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Without try/finally an exception in the with-body would leave the
        # logger permanently renamed.
        logger.name = old_name
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hoon9405
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,202 @@
1
+ Metadata-Version: 2.1
2
+ Name: genhpf
3
+ Version: 1.0.11
4
+ Summary: GenHPF: General Healthcare Predictive Framework with Multi-task Multi-source Learning
5
+ Author-email: Jungwoo Oh <ojw0123@kaist.ac.kr>, Kyunghoon Hur <pacesun@kaist.ac.kr>
6
+ License: MIT license
7
+ Classifier: Intended Audience :: Science/Research
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
10
+ Requires-Python: >=3.10.0
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: hydra-core==1.3.2
14
+ Requires-Dist: omegaconf==2.3.0
15
+ Requires-Dist: torch==2.6.0
16
+ Requires-Dist: transformers==4.49.0
17
+ Requires-Dist: h5pickle==0.4.2
18
+ Requires-Dist: scikit-learn==1.6.1
19
+ Requires-Dist: pandas==2.2.3
20
+ Requires-Dist: polars==1.17.1
21
+ Requires-Dist: pyarrow==17.0.0
22
+ Provides-Extra: dev
23
+ Requires-Dist: pre-commit; extra == "dev"
24
+ Requires-Dist: black; extra == "dev"
25
+
26
+ # GenHPF : General Healthcare Predictive Framework for Multi-task Multi-source Learning
27
+
28
+ GenHPF is a general healthcare predictive framework, which requires no medical domain knowledge and minimal preprocessing for multiple prediction tasks.
29
+
30
+ Our framework presents a method for embedding any form of EHR systems for prediction tasks without requiring domain-knowledge-based pre-processing, such as medical code mapping and feature selection.
31
+
32
+ This repository provides the official PyTorch code to implement GenHPF, a general healthcare predictive framework.
33
+
34
+ # Getting started with GenHPF
35
+ ## STEP 1: Installation
36
+ For developing locally:
37
+ ```bash
38
+ $ pip install -e ./
39
+ ```
40
+
41
+ Otherwise:
42
+ ```bash
43
+ $ pip install genhpf
44
+ ```
45
+
46
+ ## STEP 2: Prepare training data
47
+ ### Preprocessing raw datasets to reproduce GenHPF paper results (GenHPF dataset)
48
+ Download raw datasets and required tools:
49
+ * [MIMIC-III](https://physionet.org/content/mimiciii/1.4/)
50
+ * [MIMIC-IV](https://physionet.org/content/mimiciv/2.0/)
51
+ * [eICU](https://physionet.org/content/eicu-crd/2.0/)
52
+
53
+ Then, run:
54
+ ```bash
55
+ genhpf-preprocess \
56
+ --data $DATA_DIR \
57
+ --ehr {"eicu", "mimiciii", "mimiciv"} \
58
+ --dest $OUTPUT_DIR \
59
+ --first_icu \
60
+ --emb_type {"textbase", "codebase"} \
61
+ --feature {"all_features", "select"} \
62
+ --mortality \
63
+ --long_term_mortality \
64
+ ... # add desired prediction tasks
65
+ ```
66
+ This will output the processed data (`data.h5` and `label.csv`) into `$DATA_DIR/data/` directory.
67
+ For detailed descriptions for each argument, see [src/genhpf/scripts/preprocess/genhpf/README.md](src/genhpf/scripts/preprocess/genhpf/README.md).
68
+ <!-- Note that pre-processing takes about 6hours in 128 cores of AMD EPYC 7502 32-Core Processor, and requires 180GB of RAM. -->
69
+
70
+ Finally, you should prepare data manifest based on the preprocessed data:
71
+ ```bash
72
+ genhpf-manifest $data_dir $label_dir \
73
+ --dest=$output_dir \
74
+ --prefix=$prefix \
75
+ --valid_percent=$valid_percent
76
+ ```
77
+ This will generate the manifest files (e.g., `$prefix_train.tsv`, `$prefix_valid.tsv`, `$prefix_test.tsv`) to `$output_dir` based on the `$data_dir`, which contains `data.h5`, and `$label_dir`, which contains `label.csv`.
78
+ The ratio among train, valid, and test splits is decided by `$valid_percent`.
79
+ Note that this is useful to handle various concepts of training and test datasets.
80
+ For instance, if you want to use multiple datasets (e.g., mimiciv and eicu) for training and evaluate the model on each of the datasets separately, you can perform it by placing the corresponding manifest files (e.g., mimiciv_train, eicu_train, mimiciv_valid, eicu_valid, mimiciv_test, eicu_test) in the same data directory and specifying the following command-line arguments: `dataset.train_subset="mimiciv_train,eicu_train" dataset.combine_train_subsets=true dataset.valid_subset="mimiciv_valid,eicu_valid" dataset.test_subset="mimiciv_test,eicu_test"`.
81
+
82
+ ### Preprocessing MEDS dataset
83
+ We also provide a script to preprocess [MEDS](https://github.com/mmcdermott/MEDS-DEV) dataset with a cohort defined by [ACES](https://github.com/justin13601/ACES) or [MEDS-DEV](https://github.com/mmcdermott/MEDS-DEV) (see Task section) to run with GenHPF.
84
+
85
+ ```bash
86
+ genhpf-preprocess-meds $MEDS_DATA_DIR \
87
+ --cohort $MEDS_LABELS_DIR \
88
+ --metadata_dir $MEDS_METADATA_DIR \
89
+ --output_dir $MEDS_OUTPUT_DIR \
90
+ --workers $NUM_WORKERS
91
+ ```
92
+
93
+ * `$MEDS_DATA_DIR`: a path to the data directory containing MEDS data to be processed. It can be a directory or the exact file path with the file extension (only `.csv` or `.parquet` allowed). If provided with directory, it tries to scan all `*.csv` or `*.parquet` files contained in the directory recursively. See [this](https://github.com/mmcdermott/MEDS-DEV?tab=readme-ov-file#building-a-dataset) if you want to build a new MEDS dataset based on MIMIC-III, MIMIC-IV, and eICU.
94
+ * `$MEDS_LABELS_DIR`: a path to the label directory for a given task, which must be a result of [ACES](https://github.com/justin13601/ACES) or [MEDS-DEV](https://github.com/mmcdermott/MEDS-DEV). It can be a directory or the exact file path that has the same file extension as the MEDS dataset to be processed. The file structure of this cohort directory should be the same as that of the provided MEDS data directory (`$MEDS_DATA_DIR`) to match each cohort to its corresponding shard data. See [this](https://github.com/mmcdermott/MEDS-DEV?tab=readme-ov-file#extracting-a-task) to extract a cohort for a specific task defined in MEDS-DEV.
95
+ * `$MEDS_METADATA_DIR`: a path to the metadata directory for the input MEDS dataset, expected to contain `codes.parquet`. This is used to retrieve descriptions for codes in MEDS events and convert each code to the retrieved description. Note that if a code has no specific description in `codes.parquet`, it will just treat that code as a plain text and process the event as it is.
96
+ * `$MEDS_OUTPUT_DIR`: directory to save processed outputs.
97
+ * Enabling `--rebase` will renew this directory.
98
+ * `$NUM_WORKERS`: number of parallel workers to multi-process the script.
99
+ * **NOTE: if you encounter this error: _"polars' maximum length reached. consider installing 'polars-u64-idx'"_, please consider using more workers or installing polars-u64-idx by `pip install polars-u64-idx`.**
100
+
101
+ As a result, you will have `.h5` and `.tsv` files with the following respective structures:
102
+ * `*.h5`
103
+ ```
104
+ *.h5
105
+ └── ${cohort_id}
106
+ └── "ehr"
107
+ ├── "hi"
108
+ │ └── np.ndarray with a shape of (num_events, 3, max_length)
109
+ ├── "time"
110
+ │ └── np.ndarray with a shape of (num_events, )
111
+ └── "label"
112
+ └── binary label (0 or 1) for ${cohort_id} given the defined task
113
+ ```
114
+ * `${cohort_id}`: `${patient_id}_${cohort_number}`, standing for **N-th cohort in the patient**.
115
+ * Numpy array under `"hi"`
116
+ * `[:, 0, :]`: token input ids (i.e., `input_ids`) for the tokenized events.
117
+ * `[:, 1, :]`: token type ids (i.e., `type_ids`) to distinguish where each input token comes from (special tokens such as `[CLS]` or `[SEP]`, column keys, or column values).
118
+ * `[:, 2, :]`: tokens indicating digit places for number type tokens (i.e., `dpe_ids`). It assigns different ids to each of digit places for numeric (integer or float) items.
119
+ * Numpy array under `"time"`
120
+ * Elapsed time in minutes from the first event to the last event. We do not use this feature currently, but reserve it for future usage (e.g., developing a method to embed events with their temporal features).
121
+ * `*.tsv`
122
+ ```
123
+ patient_id num_events
124
+ 0 10001472_0 13
125
+ 1 10002013_0 47
126
+ 2 10002013_1 46
127
+ … … …
128
+ ```
129
+
130
+ > \[!Note\]
131
+ > GenHPF preprocessing requires a tokenizer from HuggingFace (`emilyalsentzer/Bio_ClinicalBERT`), which means internet access is needed during the initial setup to download the tokenizer.
132
+ > If you are working in a network-restricted setting, you can manually download the tokenizer and load it from a local path.
133
+
134
+ ## STEP 3: Training a new model
135
+ We prepared example configuration files for various models and experimental setups.
136
+ For detailed configurations, please see [configs.py](src/genhpf/configs/configs.py) and each implemented source code (e.g., [genhpf.py](src/genhpf/models/genhpf.py)).
137
+
138
+ ### Examples to process GenHPF dataset
139
+ ### Train a new GenHPF model from scratch:
140
+ ```bash
141
+ genhpf-train \
142
+ dataset.data=??? \
143
+ --config-dir ${GENHPF_DIR}/examples/train/genhpf \
144
+ --config-name genhpf_hierarchical_scr
145
+ ```
146
+ Note that you should fill in `dataset.data=???` with a path to the directory that contains the data manifest files (e.g., `train.tsv`, `valid.tsv`, etc.) for the processed GenHPF data.
147
+
148
+ ### Pre-train and fine-tune a new GenHPF model:
149
+ For pre-training with SimCLR:
150
+ ```bash
151
+ genhpf-train \
152
+ dataset.data=??? \
153
+ --config-dir ${GENHPF_DIR}/examples/pretrain/simclr/genhpf \
154
+ --config-name genhpf_hierarchical_pt
155
+ ```
156
+ For fine-tuning:
157
+ ```bash
158
+ genhpf-train \
159
+ dataset.data=??? \
160
+ model.from_pretrained=${/path/to/the/pretrained/checkpoint.pt} \
161
+ --config-dir ${GENHPF_DIR}/examples/train/genhpf \
162
+ --config-name genhpf_hierarchical_ft
163
+ ```
164
+
165
+ ### Examples to process MEDS dataset
166
+ ```bash
167
+ genhpf-train \
168
+ dataset.data=??? \
169
+ --config-dir ${GENHPF_DIR}/examples/train/genhpf \
170
+ --config-name meds_hierarchical_scr
171
+ ```
172
+ Note that you should fill in `dataset.data=???` with a path to the directory that contains the data manifest files (e.g., `train.tsv`, `tuning.tsv`, etc.) for the processed MEDS data (i.e., `$MEDS_OUTPUT_DIR`).
173
+
174
+ For doing inference on MEDS dataset while outputting prediction results to evaluate the model using [meds-evaluation](https://github.com/kamilest/meds-evaluation):
175
+ ```bash
176
+ genhpf-test \
177
+ dataset.data=??? \
178
+ meds.output_predictions=true \
179
+ meds.labels_dir=$MEDS_LABELS_DIR \
180
+ meds.output_dir=$OUTPUT_DIR \
181
+ checkpoint.load_checkpoint=${/path/to/the/trained/checkpoint.pt} \
182
+ --config-dir ${GENHPF_DIR}/examples/test/genhpf \
183
+ --config-name meds_hierarchical
184
+ ```
185
+ This script will load the model weights from `${/path/to/the/trained/checkpoint.pt}`, process the data specified by `dataset.data`, and output the prediction results for the test subset as a single parquet file to `$OUTPUT_DIR` directory.
186
+ Note that the data directory `dataset.data` should contain the directory for the test data with its manifest file (e.g., `held_out/*.h5` with `held_out.tsv`), where the name of the test subset is specified by `dataset.test_subset` config.
187
+
188
+ ## Citation
189
+ If you find GenHPF useful for your research and applications, please cite using this BibTeX:
190
+ ```bibtex
191
+
192
+ @article{hur2023genhpf,
193
+ title={GenHPF: General Healthcare Predictive Framework for Multi-task Multi-source Learning},
194
+ author={Hur, Kyunghoon and Oh, Jungwoo and Kim, Junu and Kim, Jiyoun and Lee, Min Jae and Cho, Eunbyeol and Moon, Seong-Eun and Kim, Young-Hak and Atallah, Louis and Choi, Edward},
195
+ journal={IEEE Journal of Biomedical and Health Informatics},
196
+ year={2023},
197
+ publisher={IEEE}
198
+ }
199
+ ```
200
+
201
+ # License
202
+ This repository is MIT-licensed.
@@ -0,0 +1,67 @@
1
+ genhpf/__init__.py,sha256=uh6oTFMxEX_AwRqlfDmNeS3kU4QhY-KXG6nsQ2kjWNo,219
2
+ genhpf/trainer.py,sha256=v8wadlwI_HCopbCyEkaHw_abu2MscPibJjBWMg5pFw0,13339
3
+ genhpf/configs/__init__.py,sha256=L0heECTJaH5SyESeCWxbnpjAnJAIh8z05M8--DlQI8k,393
4
+ genhpf/configs/config.yaml,sha256=0Y8eL7b8lh3ZVSO8h7JhTPHi_CcPQ69zBv-2iTocjAg,63
5
+ genhpf/configs/configs.py,sha256=WpO_EzUoM32sKVtiVV4ynKrMGSt1Crdjf1C0Sc9Rhfg,10723
6
+ genhpf/configs/constants.py,sha256=B4mzuJpD4V-saixyy2n5LMGQ6brLdqHHlkmCvYK2AGo,855
7
+ genhpf/configs/initialize.py,sha256=cvhFASxB-QoapXpjnQvivDS9U6e_LeAjB1aa6e1eoPo,1981
8
+ genhpf/configs/utils.py,sha256=v0zxL54VJNd7Y9MxRCf3KBLg7fY3w4kGUvUfI43ikJU,1091
9
+ genhpf/criterions/__init__.py,sha256=-oEk7_R8plPSo76-BWmODg610bp6IGHIOMLlTu6ZU3k,2452
10
+ genhpf/criterions/binary_cross_entropy.py,sha256=yqZ9ZwsnmGYaVRWBhZsXnTpGfCrhFsn6XO8pftDzzF4,4353
11
+ genhpf/criterions/binary_cross_entropy_with_logits.py,sha256=_MpR_puhunBh9Lu4slc34H50dnUIzqkLMuZPfSgN6so,4502
12
+ genhpf/criterions/criterion.py,sha256=o2oaFvADYAxjymHR8diEFC39ij8k0HZw03KAMm12OAw,3274
13
+ genhpf/criterions/cross_entropy.py,sha256=LuWxInQyEpM1DF1mkyhI_QyZKGrQ2J4LjP8_mlFigsc,7946
14
+ genhpf/criterions/multi_task_criterion.py,sha256=6ZKf6AVKHPE_wNikqEyQ7voFq53HFKFUod6DqxSc9aA,7226
15
+ genhpf/criterions/simclr_criterion.py,sha256=Dhed8JrR-aFNekSeZPKHO8pcCzHZUoBnd3LmLQ0-Y1w,3084
16
+ genhpf/criterions/wav2vec2_criterion.py,sha256=ZGMYqlLBz3AEgB1e_Ip4MbhBJuWjPax6TXt62__eXnE,4891
17
+ genhpf/datasets/__init__.py,sha256=exbMv-8_nQUqgebzzbLmyGILaWanIoiEud__C48F7iI,3613
18
+ genhpf/datasets/dataset.py,sha256=lwD6MqLAb6TicD0CL9Nv2j0yJD5hrRjRgBMO3vAVYBA,4586
19
+ genhpf/datasets/genhpf_dataset.py,sha256=veuqoKZb_x2wXVLQ4yFb36Fbh55IPPk-er_4gZ4aluw,16139
20
+ genhpf/datasets/meds_dataset.py,sha256=a6uaxH7MkX4gkaGyJr1RhGrL9WpqR5Bae1PJUUvA_lo,7994
21
+ genhpf/loggings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ genhpf/loggings/meters.py,sha256=ECdJTwFHx_4D22iNbv9VRxlh9iibX8aU9QeHPkqNmXQ,10795
23
+ genhpf/loggings/metrics.py,sha256=3CSBA5C3bd-G-zNer7BeOqSZj-tn6twbpLqAlt-FQ_A,3935
24
+ genhpf/loggings/progress_bar.py,sha256=9-24WAFDsp6WSS-JncnQtQMwo7DnNEYakAt7a8pkhF0,14140
25
+ genhpf/models/__init__.py,sha256=EG4YnL8Uiem8iUNm72euHJlim0IZj3inzFVFCFOvPCE,2223
26
+ genhpf/models/genhpf.py,sha256=Y9f8H3fgUm1H-QWTnRzcQMu1Pkl6i0ZNNRuSmZZ6Zh0,9712
27
+ genhpf/models/genhpf_mlm.py,sha256=rExPpm1HDjljAjgFbYx2bgS6VSaIKF6-P7VJcq6YLB0,1882
28
+ genhpf/models/genhpf_predictor.py,sha256=i-XIh7S3ozpB_r4JZI27sfdnbANyQYpBIOrDDgsiWvc,2163
29
+ genhpf/models/genhpf_simclr.py,sha256=Iuqx0fy0AQurkTk0e5hEv12eJyeGGGiQJiRKXGgOTnI,1629
30
+ genhpf/models/genhpf_wav2vec2.py,sha256=lY-0Bn7RavH5yNvtbx1vb2cfcdQ8ON2DYfuTiz0X2DQ,10959
31
+ genhpf/modules/__init__.py,sha256=lbuveico2NtfEbi96ykqboAe-qxrCb0AvkqrrYgBuqg,402
32
+ genhpf/modules/gather_layer.py,sha256=dplHmFmAV05nl-bsLA1OG2OtPKHxaWj0Z9GNTkNFvBw,657
33
+ genhpf/modules/grad_multiply.py,sha256=rYNanKk6jjewppDgZ3mit7ndw1QklANctpmw6_nH_5M,266
34
+ genhpf/modules/gumbel_vector_quantizer.py,sha256=JaKQJ6CB-gMRYP60ATWid3niygpJ7QenNRDiFyxIH30,7001
35
+ genhpf/modules/identity_layer.py,sha256=uGRqfRBNhc-ha5wNrqawh5E1pw5Lm9GDGi1stju-UF4,158
36
+ genhpf/modules/layer_norm.py,sha256=-aVKThi1pWvVMbMAzyQG1co6MHPBCUZgxWJKYzIqsPQ,902
37
+ genhpf/modules/positional_encoding.py,sha256=Rf_qHdQArljEggRO4EHufc_JHq9-i44Oog1w9Bh51DQ,754
38
+ genhpf/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
+ genhpf/scripts/test.py,sha256=DZPiZa-Tm6kKLcK3R1EH82gjq4Hbl098IAY4kA3fQxg,10288
40
+ genhpf/scripts/train.py,sha256=AoaufZfxxYsD2pSnZvWxRVxpsSv_SXGgzTStx-APiMw,13099
41
+ genhpf/scripts/preprocess/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ genhpf/scripts/preprocess/manifest.py,sha256=ZIK16e4vs_cS2K_tM1GaT38hc1nBHk6JB9Uga6OjgU4,2711
43
+ genhpf/scripts/preprocess/preprocess_meds.py,sha256=QWG5HCNvO1yuGrQM1SsrnWM3Zn18zcPpbeWbXovKhrs,26657
44
+ genhpf/scripts/preprocess/genhpf/README.md,sha256=qtpM_ABJk5yI8xbsUj1sZ71yX5bybx9ZvAymo0Lh5Vc,2877
45
+ genhpf/scripts/preprocess/genhpf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
+ genhpf/scripts/preprocess/genhpf/main.py,sha256=EF3sce0ltowMHIGK7zLEQEOnzOWQ_WJxoBowknHV3mQ,6161
47
+ genhpf/scripts/preprocess/genhpf/manifest.py,sha256=uHx0POSs9-ZB8Vtib7rPJ6hgDVJ1CBN6Ccfa4PpqmnM,2663
48
+ genhpf/scripts/preprocess/genhpf/sample_dataset.py,sha256=JzjMY2ynIYoWWtRlBG9Hxv6EoF27jJHyd3VYfqsM0Xs,5569
49
+ genhpf/scripts/preprocess/genhpf/ehrs/__init__.py,sha256=8bA4Pk0ylLIpwFQKEx6lis0k_inh4owF2SlHjHhKkeE,895
50
+ genhpf/scripts/preprocess/genhpf/ehrs/ehr.py,sha256=xWr7JIg2UeoYyxN36FI75ZnMJTyK6pXfiB7AUkIdDDw,38221
51
+ genhpf/scripts/preprocess/genhpf/ehrs/eicu.py,sha256=5w7cE9ajinpgnyRMtCkDOP82YDRc-TnObQaRXd1Ho2k,22925
52
+ genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py,sha256=lnDN8ZmXgSiDU48Z5kBUyqCezxA7Cf0yM7R6SFRRbDk,36643
53
+ genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py,sha256=q2CDXAkkHc8dA5VSWCtg7S4AbEmwKhzbohPx8961G-g,25339
54
+ genhpf/scripts/preprocess/genhpf/utils/__init__.py,sha256=N2AAYli0M8KTOQ9KCkNUQu5iMhIJmDV0xqQ_IAwHRvE,76
55
+ genhpf/scripts/preprocess/genhpf/utils/utils.py,sha256=DBauYrBtjI36hQNqZXC6kID77KZ3sYrFSOIW5_2kA0I,480
56
+ genhpf/utils/checkpoint_utils.py,sha256=Le2KmvwTxDurkqy86yD8F505dz1k3aSB-OE09_sUTsQ,6161
57
+ genhpf/utils/data_utils.py,sha256=cYJe-SednIvYer_VffCwwFByQ5SAEBQuak4NDxDK_GM,5098
58
+ genhpf/utils/distributed_utils.py,sha256=000xKlw8SLoSH16o6n2bB3eueGR0aVD_DufPYESi5k0,17654
59
+ genhpf/utils/file_io.py,sha256=hnZXdMtAibfFDoIfn-SDusl-v7ZImeUEh0eD2MIxbG4,4919
60
+ genhpf/utils/pdb.py,sha256=400rk1pVfOpVpzKIFHnTRlZ2VCtBqRh9G-pRRwu2Oqo,930
61
+ genhpf/utils/utils.py,sha256=BoC_7Gz8uCHbUBCpcXGBMD-5irApi_6xM7nU-2ac4aA,6176
62
+ genhpf-1.0.11.dist-info/LICENSE,sha256=VK_rvhY2Xi_DAIZHtauni5O9-1_do5SNWjrskv4amg8,1065
63
+ genhpf-1.0.11.dist-info/METADATA,sha256=8LF-6H-2SUqB1MO0vhwJ8TBdnC_BY-V8bOhTZc7svxo,10916
64
+ genhpf-1.0.11.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
65
+ genhpf-1.0.11.dist-info/entry_points.txt,sha256=Wp94VV2w9KasBDLaluLM5EnjLgjNOAQVu44wKRDAwmQ,288
66
+ genhpf-1.0.11.dist-info/top_level.txt,sha256=lk846Vmnvydb6UZn8xmowj60nkrZYexNOGGnPM-IbhA,7
67
+ genhpf-1.0.11.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.3.2)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,6 @@
1
+ [console_scripts]
2
+ genhpf-manifest = genhpf.scripts.manifest:main
3
+ genhpf-preprocess = genhpf.scripts.preprocess.genhpf.main:main
4
+ genhpf-preprocess-meds = genhpf.scripts.preprocess.preprocess_meds:main
5
+ genhpf-test = genhpf.scripts.test:cli_main
6
+ genhpf-train = genhpf.scripts.train:cli_main
@@ -0,0 +1 @@
1
+ genhpf