genhpf 1.0.11 (genhpf-1.0.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/preprocess/genhpf/main.py
@@ -0,0 +1,175 @@
+import argparse
+import logging
+import os
+import sys
+
+from ehrs import EHR_REGISTRY
+from pyspark.sql import SparkSession
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout,
+)
+logger = logging.getLogger(__name__)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    # task
+    parser.add_argument(
+        "--task",
+        type=str,
+        default=None,
+        help="specific task from the all_features pipeline." "if not set, run the all_features steps.",
+    )
+    parser.add_argument("--dest", default="outputs", type=str, metavar="DIR", help="output directory")
+
+    # data
+    parser.add_argument(
+        "--ehr",
+        type=str,
+        required=True,
+        choices=["mimiciii", "mimiciv", "eicu"],
+        help="name of the ehr system to be processed.",
+    )
+    parser.add_argument(
+        "--data",
+        metavar="DIR",
+        default=None,
+        help="directory containing data files of the given ehr (--ehr)."
+        "if not given, try to download from the internet.",
+    )
+    parser.add_argument(
+        "--ext",
+        type=str,
+        default=None,
+        help="extension for ehr data to look for. " "if not given, try to infer from --data",
+    )
+    parser.add_argument(
+        "--ccs",
+        type=str,
+        default=None,
+        help="path to `ccs_multi_dx_tool_2015.csv`" "if not given, try to download from the internet.",
+    )
+    parser.add_argument(
+        "--gem",
+        type=str,
+        default=None,
+        help="path to `icd10cmtoicd9gem.csv`" "if not given, try to download from the internet.",
+    )
+
+    parser.add_argument(
+        "-c", "--cache", action="store_true", help="whether to load data from cache if exists"
+    )
+
+    # misc
+    parser.add_argument("--max_event_size", type=int, default=256, help="max event size to crop to")
+    parser.add_argument(
+        "--min_event_size",
+        type=int,
+        default=5,
+        help="min event size to skip small samples",
+    )
+    parser.add_argument("--min_age", type=int, default=18, help="min age to skip too young patients")
+    parser.add_argument("--max_age", type=int, default=None, help="max age to skip too old patients")
+    parser.add_argument("--obs_size", type=int, default=12, help="observation window size by the hour")
+    parser.add_argument("--gap_size", type=int, default=12, help="time gap window size by the hour")
+    parser.add_argument("--pred_size", type=int, default=24, help="prediction window size by the hour")
+    parser.add_argument(
+        "--long_term_pred_size",
+        type=int,
+        default=336,
+        help="prediction window size by the hour (for long term mortality task)",
+    )
+    parser.add_argument(
+        "--first_icu",
+        action="store_true",
+        help="whether to use only the first icu or not",
+    )
+
+    # tasks
+    parser.add_argument("--mortality", action="store_true", help="whether to include mortality task or not")
+    parser.add_argument(
+        "--long_term_mortality",
+        action="store_true",
+        help="whether to include long term mortality task or not",
+    )
+    parser.add_argument("--los_3day", action="store_true", help="whether to include 3-day los task or not")
+    parser.add_argument("--los_7day", action="store_true", help="whether to include 7-day los task or not")
+    parser.add_argument(
+        "--readmission", action="store_true", help="whether to include readmission task or not"
+    )
+    parser.add_argument(
+        "--final_acuity", action="store_true", help="whether to include final acuity task or not"
+    )
+    parser.add_argument(
+        "--imminent_discharge", action="store_true", help="whether to include imminent discharge task or not"
+    )
+    parser.add_argument("--diagnosis", action="store_true", help="whether to include diagnosis task or not")
+    parser.add_argument("--creatinine", action="store_true", help="whether to include creatinine task or not")
+    parser.add_argument("--bilirubin", action="store_true", help="whether to include bilirubin task or not")
+    parser.add_argument("--platelets", action="store_true", help="whether to include platelets task or not")
+    parser.add_argument(
+        "--wbc", action="store_true", help="whether to include blood white blood cell count task or not"
+    )
+
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        default=1024,
+        help="chunk size to read large csv files",
+    )
+    parser.add_argument("--bins", type=int, default=20, help="num buckets to bin time intervals by")
+
+    parser.add_argument(
+        "--max_event_token_len", type=int, default=192, help="max token length for each event (Hierarchical)"
+    )
+
+    parser.add_argument(
+        "--max_patient_token_len", type=int, default=8192, help="max token length for each patient (Flatten)"
+    )
+
+    parser.add_argument("--num_threads", type=int, default=8, help="number of threads to use")
+
+    # CodeEmb / Feature select
+    parser.add_argument(
+        "--emb_type",
+        type=str,
+        choices=["codebase", "textbase"],
+        default="textbase",
+        help="feature embedding type, codebase model = [SAND, Rajikomar], textbase model = [GenHPF, DescEmb]",
+    )
+    parser.add_argument(
+        "--feature",
+        choices=["select", "all_features"],
+        default="all_features",
+        help="pre-define feature select or all_features table use",
+    )
+    parser.add_argument("--bucket_num", type=int, default=10, help="feature bucket num")
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    if not os.path.exists(args.dest):
+        os.makedirs(args.dest)
+
+    ehr = EHR_REGISTRY[args.ehr](args)
+
+    spark = (
+        SparkSession.builder.master(f"local[{args.num_threads}]")
+        .config("spark.driver.memory", "100g")
+        .config("spark.driver.maxResultSize", "10g")
+        .config("spark.network.timeout", "100s")
+        .appName("Main_Preprocess")
+        .getOrCreate()
+    )
+    ehr.run_pipeline(spark)
+
+
+if __name__ == "__main__":
+    main()
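Editorial note (not part of the package contents): a minimal sketch of how the parser defined in main.py above could be driven with an explicit argument list. The paths and the selected task flags are placeholders, and the import assumes the preprocess directory and its dependencies (ehrs, pyspark) are importable.

# Hedged illustration only; argument values below are placeholders, not package defaults.
from main import get_parser  # assumes genhpf/scripts/preprocess/genhpf is on sys.path

args = get_parser().parse_args([
    "--ehr", "mimiciv",
    "--data", "/path/to/mimiciv",  # placeholder: directory of raw MIMIC-IV files
    "--dest", "outputs",
    "--first_icu",
    "--mortality",
    "--readmission",
])
print(args.obs_size, args.gap_size, args.pred_size)  # package defaults: 12 12 24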
genhpf/scripts/preprocess/genhpf/manifest.py
@@ -0,0 +1,79 @@
+import argparse
+import os
+import random
+
+import h5pickle
+from tqdm import tqdm
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "root", metavar="DIR", type=str, help="root directory containing the data.h5 and label.csv"
+    )
+    parser.add_argument("--dest", default=".", type=str, metavar="DIR", help="output directory")
+    parser.add_argument("--prefix", default="", type=str, help="prefix for the output files")
+    parser.add_argument(
+        "--valid-percent",
+        default=0.1,
+        type=float,
+        metavar="D",
+        help="percentage of data to use as validation and test set (between 0 and 0.5)",
+    )
+    parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
+    return parser
+
+
+def main(args):
+    assert 0 <= args.valid_percent <= 0.5, (
+        f"Invalid valid-percent: {args.valid_percent}. " "Must be between 0 and 0.5."
+    )
+
+    if len(args.prefix) > 0 and not args.prefix.endswith("_"):
+        args.prefix += "_"
+
+    root_path = os.path.realpath(args.root)
+    data_path = os.path.join(root_path, "data.h5")
+    label_path = os.path.join(root_path, "label.csv")
+    rand = random.Random(args.seed)
+
+    if not os.path.exists(args.dest):
+        os.makedirs(args.dest)
+
+    with (
+        open(os.path.join(args.dest, f"{args.prefix}train.tsv"), "w") as train_f,
+        open(os.path.join(args.dest, f"{args.prefix}valid.tsv"), "w") as valid_f,
+        open(os.path.join(args.dest, f"{args.prefix}test.tsv"), "w") as test_f,
+    ):
+        print(data_path, file=train_f)
+        print(label_path, file=train_f)
+        print(data_path, file=valid_f)
+        print(label_path, file=valid_f)
+        print(data_path, file=test_f)
+        print(label_path, file=test_f)
+
+        def write(subjects, dest, split=None):
+            for subject in tqdm(subjects, total=len(subjects), desc=split):
+                print(subject, file=dest)
+
+        data = h5pickle.File(data_path, "r")["ehr"]
+        subjects = list(data.keys())
+        rand.shuffle(subjects)
+
+        valid_len = int(len(subjects) * args.valid_percent)
+        test_len = int(len(subjects) * args.valid_percent)
+        train_len = len(subjects) - valid_len - test_len
+
+        train = subjects[:train_len]
+        valid = subjects[train_len : train_len + valid_len]
+        test = subjects[train_len + valid_len :]
+
+        write(train, train_f, split=f"{args.prefix}train")
+        write(valid, valid_f, split=f"{args.prefix}valid")
+        write(test, test_f, split=f"{args.prefix}test")
+
+
+if __name__ == "__main__":
+    parser = get_parser()
+    args = parser.parse_args()
+    main(args)
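Editorial note (not package code): the manifest files written by the script above share a simple layout, which the sketch below reads back. Each of {prefix}train.tsv / valid.tsv / test.tsv starts with the data.h5 path and the label.csv path, followed by one subject key per line; the filename "train.tsv" here is a placeholder.

# Hedged illustration of consuming a manifest produced by manifest.py above.
with open("train.tsv") as f:
    data_path = f.readline().strip()   # line 1: path to data.h5
    label_path = f.readline().strip()  # line 2: path to label.csv
    subjects = [line.strip() for line in f if line.strip()]  # remaining lines: subject keys
print(data_path, label_path, len(subjects))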
genhpf/scripts/preprocess/genhpf/sample_dataset.py
@@ -0,0 +1,177 @@
+import argparse
+import os
+
+import h5py
+import torch
+from torch.nn.functional import one_hot
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset
+
+"""
+Sample Dataset for Integrated-EHR-Pipeline
+NOTE: The preprocessed token indices are stored as np.int16 type for efficienty,
+    so overflow can be caused when vocab size > 32767.
+- ehr.cohort.labeled.index: dataframe pickle
+    - index: should be resetted by `df.reset_index()`
+    - [hi_start:hi_end] is the range of indices for the patient's history
+    - split: one of ['train', 'val', 'test']
+- ehr.data: .np file with np.int16 type, (num_total_events, 3, max_word_len)
+    - second dimension
+        - 0: input_ids
+        - 1: type_ids
+        - 2: dpe_ids
+"""
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ehr",
+        type=str,
+        required=True,
+        choices=["mimiciii", "mimiciv", "eicu"],
+        help="Name of the EHR dataset",
+    )
+    parser.add_argument("--data", type=str, required=True, help="Path to the preprocessed data")
+    parser.add_argument(
+        "--pred_target",
+        type=str,
+        required=True,
+        choices=[
+            "mortality",
+            "readmission",
+            "los_3day",
+            "los_7day",
+            "final_acuity",
+            "imminent_discharge",
+            "diagnosis",
+        ],
+        help="Prediction target",
+    )
+    parser.add_argument("--max_event_size", type=int, default=256, help="max event size to crop to")
+    parser.add_argument(
+        "--max_event_token_len", type=int, default=128, help="max token length for each event (Hierarchical)"
+    )
+
+    parser.add_argument(
+        "--max_patient_token_len", type=int, default=8192, help="max token length for each patient (Flatten)"
+    )
+    return parser
+
+
+class BaseEHRDataset(Dataset):
+    def __init__(self, args, split):
+        super().__init__()
+
+        self.args = args
+        self.data_path = os.path.join(args.data, f"{args.ehr}.h5")
+
+        self.data = h5py.File(self.data_path, "r")["ehr"]
+        self.keys = []
+        for key in self.data.keys():
+            if self.data[key].attrs["split"] == split:
+                self.keys.append(key)
+        self.pred_target = args.pred_target
+
+        self.num_classes = {
+            "mimiciii": {
+                "mortality": 2,
+                "readmission": 2,
+                "los_3day": 2,
+                "los_7day": 2,
+                "final_acuity": 18,
+                "imminent_discharge": 18,
+                "diagnosis": 18,
+            },
+            "mimiciv": {
+                "mortality": 2,
+                "readmission": 2,
+                "los_3day": 2,
+                "los_7day": 2,
+                "final_acuity": 14,
+                "imminent_discharge": 14,
+                "diagnosis": 18,
+            },
+            "eicu": {
+                "mortality": 2,
+                "readmission": 2,
+                "los_3day": 2,
+                "los_7day": 2,
+                "final_acuity": 9,
+                "imminent_discharge": 9,
+                "diagnosis": 18,
+            },
+        }[self.args.ehr][self.args.pred_target]
+
+    def __len__(self):
+        return len(self.keys)
+
+    def __getitem__(self, idx):
+        raise NotImplementedError()
+
+    def collate_fn(self, out):
+        ret = dict()
+        if len(out) == 1:
+            for k, v in out[0].items():
+                if k == "label":
+                    ret[k] = (
+                        one_hot(torch.LongTensor([v]), self.num_classes)
+                        .float()
+                        .reshape((1, self.num_classes))
+                    )
+                else:
+                    ret[k] = v.unsqueeze(0)
+        else:
+            for k, v in out[0].items():
+                if k == "label":
+                    ret[k] = one_hot(torch.LongTensor([i[k] for i in out]), self.num_classes).float()
+                else:
+                    ret[k] = pad_sequence([i[k] for i in out], batch_first=True)
+        return ret
+
+
+class HierarchicalEHRDataset(BaseEHRDataset):
+    def __init__(self, args, split):
+        super().__init__(args, split)
+
+    def __getitem__(self, idx):
+        # NOTE: Warning occurs when converting np.int16 read-only array into tensor, but ignoreable
+
+        data = self.data[self.keys[idx]]["hi"]
+        label = self.data[self.keys[idx]].attrs[self.pred_target]
+        return {
+            "input_ids": torch.IntTensor(data[:, 0, :]),
+            "type_ids": torch.IntTensor(data[:, 1, :]),
+            "dpe_ids": torch.IntTensor(data[:, 2, :]),
+            "label": label,
+        }
+
+
+class FlattenEHRDataset(BaseEHRDataset):
+    def __init__(self, args, split):
+        super().__init__(args, split)
+        self.data_path = os.path.join(args.data, f"{args.ehr}.flat.npy")
+
+    def __getitem__(self, idx):
+        # NOTE: Warning occurs when converting np.int16 read-only array into tensor, but ignoreable
+        data = self.data[self.keys[idx]]["fl"]
+        label = self.data[self.keys[idx]].attrs[self.pred_target]
+        return {
+            "input_ids": torch.IntTensor(data[0, :]),
+            "type_ids": torch.IntTensor(data[1, :]),
+            "dpe_ids": torch.IntTensor(data[2, :]),
+            "label": label,
+        }
+
+
+def main():
+    args = get_parser().parse_args()
+    dataset = HierarchicalEHRDataset(args, "train")
+    print(dataset.__getitem__(1))
+    dataset = FlattenEHRDataset(args, "train")
+    print(dataset.__getitem__(1))
+    pass
+
+
+if __name__ == "__main__":
+    main()
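Editorial note (not part of the package contents): the module docstring above describes the per-event layout (input_ids / type_ids / dpe_ids). As an illustration only, the dataset can be wired into a standard PyTorch DataLoader through its collate_fn; the --data directory below is a placeholder and is assumed to contain mimiciii.h5 produced by the pipeline, with sample_dataset.py importable as a module.

from torch.utils.data import DataLoader

from sample_dataset import HierarchicalEHRDataset, get_parser

# Hedged illustration; argument values are placeholders.
args = get_parser().parse_args(
    ["--ehr", "mimiciii", "--data", "/path/to/preprocessed", "--pred_target", "mortality"]
)
dataset = HierarchicalEHRDataset(args, split="train")
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
batch = next(iter(loader))
# input_ids: (batch, padded num events, max word len); label: one-hot (batch, num_classes)
print(batch["input_ids"].shape, batch["label"].shape)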
genhpf/scripts/preprocess/genhpf/utils/utils.py
@@ -0,0 +1,16 @@
+import numpy as np
+import pandas as pd
+
+
+def col_name_add(x, cate_col):
+    if not (x == "nan" or x == pd.isnull(x)):
+        return cate_col + "_" + str(x)
+    else:
+        return x  # return nan
+
+
+def q_cut(x, cuts):
+    unique_var = len(np.unique([i for i in x]))
+    nunique = len(pd.qcut(x, min(unique_var, cuts), duplicates="drop").cat.categories)
+    output = pd.qcut(x, min(unique_var, cuts), labels=range(1, min(nunique, cuts) + 1), duplicates="drop")
+    return output
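Editorial note (not package code): a small illustration of the two helpers above. col_name_add prefixes a categorical value with its column name, and q_cut buckets a numeric pandas Series into at most `cuts` quantile bins labeled from 1. The bare `utils` import assumes the module shown above is importable from the working directory.

import pandas as pd

from utils import col_name_add, q_cut  # hedged: path/layout assumption

print(col_name_add("120", "heartrate"))  # -> "heartrate_120"
values = pd.Series([0.5, 1.2, 3.3, 3.3, 7.8, 10.1, 15.0, 20.2])
print(q_cut(values, 5).tolist())  # quantile-bin labels in 1..5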
genhpf/scripts/preprocess/manifest.py
@@ -0,0 +1,83 @@
+import argparse
+import random
+import os
+import logging
+import h5pickle
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "data", type=str, help="path to the .h5 file containing the data"
+    )
+    parser.add_argument(
+        "label", type=str, help="path to the .csv file containing the labels corresponded to the data"
+    )
+    parser.add_argument(
+        "--dest", default=".", type=str, metavar="DIR", help="output directory"
+    )
+    parser.add_argument(
+        "--prefix", default="", type=str, help="prefix for the output files"
+    )
+    parser.add_argument(
+        "--valid-percent",
+        default=0.1,
+        type=float,
+        metavar="D",
+        help="percentage of the data to use as validation and test set (between 0 and 0.5)"
+    )
+    parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
+    return parser
+
+def main(args):
+    assert 0 <= args.valid_percent <= 0.5
+
+    data_path = os.path.realpath(args.data)
+    label_path = os.path.realpath(args.label)
+    prefix = args.prefix
+    if len(prefix) > 0 and not prefix.endswith("_"):
+        prefix += "_"
+    rand = random.Random(args.seed)
+
+    if not os.path.exists(args.dest):
+        os.makedirs(args.dest)
+
+    with (
+        open(os.path.join(args.dest, f"{prefix}train.tsv"), "w") as train_f,
+        open(os.path.join(args.dest, f"{prefix}valid.tsv"), "w") as valid_f,
+        open(os.path.join(args.dest, f"{prefix}test.tsv"), "w") as test_f
+    ):
+        print(data_path, file=train_f)
+        print(label_path, file=train_f)
+        print(data_path, file=valid_f)
+        print(label_path, file=valid_f)
+        print(data_path, file=test_f)
+        print(label_path, file=test_f)
+
+        def write(subject_ids, dest, split):
+            for subject_id in tqdm(subject_ids, total=len(subject_ids), desc=split):
+                print(subject_id, file=dest)
+
+        data = h5pickle.File(data_path, "r")["ehr"]
+        subject_ids = list(data.keys())
+
+        rand.shuffle(subject_ids)
+
+        valid_len = int(len(subject_ids) * args.valid_percent)
+        test_len = int(len(subject_ids) * args.valid_percent)
+        train_len = len(subject_ids) - valid_len - test_len
+
+        train = subject_ids[:train_len]
+        valid = subject_ids[train_len: train_len + valid_len]
+        test = subject_ids[train_len + valid_len:]
+
+        write(train, train_f, split=f"{prefix}train")
+        write(valid, valid_f, split=f"{prefix}valid")
+        write(test, test_f, split=f"{prefix}test")
+
+if __name__ == "__main__":
+    parser = get_parser()
+    args = parser.parse_args()
+    main(args)