genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,919 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import pickle
|
|
4
|
+
import re
|
|
5
|
+
import shutil
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
8
|
+
from functools import reduce
|
|
9
|
+
from itertools import chain
|
|
10
|
+
from typing import List, Union
|
|
11
|
+
|
|
12
|
+
import h5py
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import pyspark.sql.functions as F
|
|
16
|
+
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
from transformers import AutoTokenizer
|
|
19
|
+
from utils import col_name_add, q_cut
|
|
20
|
+
|
|
21
|
+
# Module-level logger named after this module; used for progress and warning messages below.
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EHR(object):
|
|
25
|
+
def __init__(self, cfg):
    """Store preprocessing options from ``cfg`` and prepare working directories.

    Creates ``~/.cache/ehr`` plus ``<cfg.dest>/data`` and ``<cfg.dest>/metadata``,
    loads the ``emilyalsentzer/Bio_ClinicalBERT`` tokenizer and sets the
    special-token and token-type ids used when encoding events. The
    ``_*_fname`` / ``_*_key`` attributes are initialized to ``None`` here
    (presumably filled in by the EHR-specific subclasses).

    Args:
        cfg: configuration object; every attribute read below must exist.

    Raises:
        AssertionError: if the event-size or age bounds are inconsistent.
        ValueError: if ``cfg.emb_type`` is neither ``"textbase"`` nor ``"codebase"``.
    """
    self.cfg = cfg

    self.cache = cfg.cache

    cache_dir = os.path.expanduser("~/.cache/ehr")
    # exist_ok avoids the race of a separate "exists?" check before creating.
    os.makedirs(cache_dir, exist_ok=True)
    self.cache_dir = cache_dir

    if self.cache:
        logger.warning(
            "--cache is set to True. Note that it forces to load cached"
            f" data from {cache_dir},"
            " which may ignore some arguments such as --first_icu, as well as task-related"
            " arguments (--mortality, --los_3day, etc.)"
            " If you want to avoid this, do not set --cache to True."
        )

    self.data_dir = cfg.data
    self.ccs_path = cfg.ccs
    self.gem_path = cfg.gem
    self.ext = cfg.ext

    # ``None`` means "unbounded": no upper limit, and at least one event.
    self.max_event_size = cfg.max_event_size if cfg.max_event_size is not None else sys.maxsize
    self.min_event_size = cfg.min_event_size if cfg.min_event_size is not None else 1
    assert self.min_event_size > 0, (
        f"--min_event_size must be positive, got {self.min_event_size}"
    )
    assert self.min_event_size <= self.max_event_size, (
        f"--min_event_size ({self.min_event_size}) must not exceed "
        f"--max_event_size ({self.max_event_size})"
    )

    self.max_event_token_len = cfg.max_event_token_len
    self.max_patient_token_len = cfg.max_patient_token_len

    self.max_age = cfg.max_age if cfg.max_age is not None else sys.maxsize
    self.min_age = cfg.min_age if cfg.min_age is not None else 0
    assert self.min_age <= self.max_age, (
        f"--min_age ({self.min_age}) must not exceed --max_age ({self.max_age})"
    )

    self.obs_size = cfg.obs_size
    self.gap_size = cfg.gap_size
    self.pred_size = cfg.pred_size
    self.long_term_pred_size = cfg.long_term_pred_size

    self.first_icu = cfg.first_icu

    # Emb_type / feature_select
    self.emb_type = cfg.emb_type
    if self.emb_type not in ("textbase", "codebase"):
        # Fail here instead of later with a missing cls_token_id / sep_token_id.
        raise ValueError(
            "emb_type must be 'textbase' or 'codebase', got {!r}".format(self.emb_type)
        )
    self.feature = cfg.feature
    self.bucket_num = cfg.bucket_num

    # tasks
    self.mortality = cfg.mortality
    self.long_term_mortality = cfg.long_term_mortality
    self.los_3day = cfg.los_3day
    self.los_7day = cfg.los_7day
    self.readmission = cfg.readmission
    self.final_acuity = cfg.final_acuity
    self.imminent_discharge = cfg.imminent_discharge
    self.diagnosis = cfg.diagnosis
    self.creatinine = cfg.creatinine
    self.bilirubin = cfg.bilirubin
    self.platelets = cfg.platelets
    self.wbc = cfg.wbc

    self.chunk_size = cfg.chunk_size

    self.dest = cfg.dest

    self.bins = cfg.bins

    os.makedirs(os.path.join(self.dest, "data"), exist_ok=True)
    os.makedirs(os.path.join(self.dest, "metadata"), exist_ok=True)

    self.special_tokens_dict = dict()
    self.max_special_tokens = 100

    self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    if self.emb_type == "textbase":
        self.cls_token_id = self.tokenizer.cls_token_id
        self.sep_token_id = self.tokenizer.sep_token_id
    else:  # "codebase": ids 1 / 2 are reserved for CLS / SEP in the codebook
        self.cls_token_id = 1
        self.sep_token_id = 2

    # Token-type ids attached to every input token.
    self.table_type_id = 1
    self.column_type_id = 2
    self.value_type_id = 3
    self.timeint_type_id = 4
    self.cls_type_id = 5
    self.sep_type_id = 6

    # DPE id used for CLS / SEP positions.
    self.others_dpe_id = 0

    self._icustay_fname = None
    self._patient_fname = None
    self._admission_fname = None
    self._diagnosis_fname = None

    self._icustay_key = None
    self._hadm_key = None
|
|
132
|
+
|
|
133
|
+
@property
def icustay_fname(self):
    """File name of the ICU-stay table, as stored in ``_icustay_fname``."""
    fname = self._icustay_fname
    return fname
|
|
136
|
+
|
|
137
|
+
@property
def patient_fname(self):
    """File name of the patient table, as stored in ``_patient_fname``."""
    fname = self._patient_fname
    return fname
|
|
140
|
+
|
|
141
|
+
@property
def admission_fname(self):
    """File name of the admission table, as stored in ``_admission_fname``."""
    fname = self._admission_fname
    return fname
|
|
144
|
+
|
|
145
|
+
@property
def diagnosis_fname(self):
    """File name of the diagnosis table, as stored in ``_diagnosis_fname``."""
    fname = self._diagnosis_fname
    return fname
|
|
148
|
+
|
|
149
|
+
@property
def icustay_key(self):
    """Column name identifying an ICU stay, as stored in ``_icustay_key``."""
    key = self._icustay_key
    return key
|
|
152
|
+
|
|
153
|
+
@property
def hadm_key(self):
    """Column name identifying a hospital admission, as stored in ``_hadm_key``."""
    key = self._hadm_key
    return key
|
|
156
|
+
|
|
157
|
+
@property
def patient_key(self):
    """Column name identifying a patient, as stored in ``_patient_key``.

    ``_patient_key`` is not initialized in ``EHR.__init__``, so reading this
    before it has been assigned raises ``AttributeError``.
    """
    key = self._patient_key
    return key
|
|
160
|
+
|
|
161
|
+
@property
def determine_first_icu(self):
    """Column used to order ICU stays within an admission (``_determine_first_icu``).

    ``build_cohorts`` takes ``idxmin`` / ``idxmax`` of this column per
    admission. It is not initialized in ``EHR.__init__``, so reading it
    before assignment raises ``AttributeError``.
    """
    column = self._determine_first_icu
    return column
|
|
164
|
+
|
|
165
|
+
@property
def num_special_tokens(self):
    """Number of entries currently registered in ``special_tokens_dict``."""
    registered = self.special_tokens_dict.keys()
    return len(registered)
|
|
168
|
+
|
|
169
|
+
def build_cohorts(self, icustays, cached=False):
    """Filter ICU stays into the study cohort and attach readmission labels.

    Args:
        icustays: pandas DataFrame of ICU stays. It must pass
            ``self.is_compatible`` and provide ``LOS``, ``AGE``,
            ``self.hadm_key``, ``self.icustay_key`` and
            ``self.determine_first_icu`` columns. The caller's frame is not
            modified.
        cached: if True and ``<ehr_name>.cohorts`` is in the cache, return
            the cached object without rebuilding.

    Returns:
        The filtered cohort DataFrame, sorted by ``(hadm_key, icustay_key)``.
        It carries a ``readmission`` column when ``self.readmission`` is set.
        The result is also saved to the cache.

    Raises:
        AssertionError: if ``icustays`` lacks the columns required by
            ``self.is_compatible``.
    """
    if cached:
        cohorts = self.load_from_cache(self.ehr_name + ".cohorts")
        if cohorts is not None:
            return cohorts

    if not self.is_compatible(icustays):
        raise AssertionError(
            "{} do not have required columns to build cohorts.".format(self.icustay_fname)
            + " Please make sure that dataframe for icustays is compatible with other ehrs."
        )

    logger.info("Start building cohorts for {}".format(self.ehr_name))
    logger.info("Emb_type {}, Feature {}".format(self.emb_type, self.feature))
    obs_size = self.obs_size
    gap_size = self.gap_size

    # (obs_size + gap_size) / 24 converts the hour-based windows into the unit
    # LOS is compared in (days). Keep long-enough stays within the age bounds.
    keep = (
        (icustays["LOS"] >= (obs_size + gap_size) / 24)
        & (self.min_age <= icustays["AGE"])
        & (icustays["AGE"] <= self.max_age)
    )

    # We define labels for the readmission task in this step because it
    # requires observing each next ICU stay, which may be excluded from the
    # final cohorts. sort_values (without inplace) returns a new frame, so
    # the column assignments below never write into a slice of the caller's
    # DataFrame.
    icustays = icustays.loc[keep].sort_values([self.hadm_key, self.icustay_key])

    if self.readmission:
        # Every stay except the one with the largest determine_first_icu value
        # in its admission is followed by another ICU stay -> readmission = 1.
        icustays["readmission"] = 1
        last_stay_idx = icustays.groupby(self.hadm_key)[self.determine_first_icu].idxmax()
        icustays.loc[last_stay_idx, "readmission"] = 0
    if self.first_icu:
        first_stay_idx = icustays.groupby(self.hadm_key)[self.determine_first_icu].idxmin()
        icustays = icustays.loc[first_stay_idx]

    logger.info("cohorts have been built successfully. Loaded {} cohorts.".format(len(icustays)))

    self.save_to_cache(icustays, self.ehr_name + ".cohorts")

    return icustays
|
|
207
|
+
|
|
208
|
+
# TODO process specific tasks according to user choice?
|
|
209
|
+
def prepare_tasks(self, cohorts, spark, cached=False):
    """Attach labels for every enabled prediction task to ``cohorts``.

    ``DISCHTIME`` is compared against ``obs_size * 60``, ``gap_size * 60`` and
    ``pred_size * 60``. It is therefore treated as minutes relative to the
    window origin (presumably ``INTIME`` — confirm in the subclass that builds
    the cohort).

    Args:
        cohorts: pandas DataFrame produced by ``build_cohorts``.
        spark: not used here; kept for interface compatibility.
        cached: if True, return the cached ``<ehr_name>.cohorts.labeled`` object.

    Returns:
        A copy of the id/time columns of ``cohorts`` plus one column per
        enabled task: ``mortality``, ``long_term_mortality``, ``los_3day``,
        ``los_7day``, ``final_acuity`` and ``imminent_discharge``.
        Multi-class labels are integer category codes, with NaN for a missing
        class. Their index-to-class tables are written to
        ``<dest>/metadata/*_classes.tsv``. The result is also saved to the
        cache.

    Raises:
        RuntimeError: ``cached`` is True but no cached labeled cohorts exist.
    """
    if cached:
        labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
        if labeled_cohorts is not None:
            return labeled_cohorts
        raise RuntimeError(
            "No cached labeled cohorts found for {}; run again without --cache.".format(
                self.ehr_name
            )
        )

    logger.info("Start labeling cohorts for predictive tasks.")

    columns = [
        self.hadm_key,
        self.icustay_key,
        self.patient_key,
        "readmission",
        "LOS",
        "INTIME",
        "OUTTIME",
        "DISCHTIME",
        "IN_ICU_MORTALITY",
        "HOS_DISCHARGE_LOCATION",
    ]
    if "readmission" not in cohorts.columns:
        # build_cohorts only adds this column when the readmission task is enabled.
        columns.remove("readmission")
    labeled_cohorts = cohorts[columns].copy()

    def encode_classes(values, fname):
        """Write tab-separated "index, class" lines to ``<dest>/metadata/<fname>``.

        Returns the integer category codes of ``values``, with NaN in place of
        pandas' ``-1`` code for missing values.
        """
        categorical = values.astype("category")
        with open(os.path.join(self.dest, "metadata", fname), "w") as f:
            for i, cat in enumerate(categorical.cat.categories):
                print("{}\t{}".format(i, cat), file=f)
        codes = categorical.cat.codes
        return codes.where(codes != -1)

    # IN_ICU_MORTALITY may be encoded as boolean / 0-1 or as the string "Death"
    # (TODO confirm against the EHR subclasses). Normalize it once so every
    # task below sees the same in-ICU deaths.
    in_icu = labeled_cohorts["IN_ICU_MORTALITY"]
    in_icu_death = (in_icu == 1) | (in_icu == "Death")
    hos_death = labeled_cohorts["HOS_DISCHARGE_LOCATION"] == "Death"
    died = in_icu_death | hos_death
    dischtime = labeled_cohorts["DISCHTIME"]

    # mortality prediction
    # If the patient died (in ICU or in hospital)
    # & intime + obs_size + gap_size < dischtime <= intime + obs_size + pred_size,
    # the stay is assigned a positive label.
    if self.mortality:
        labeled_cohorts["mortality"] = (
            died
            & (self.obs_size * 60 + self.gap_size * 60 < dischtime)
            & (dischtime <= self.obs_size * 60 + self.pred_size * 60)
        ).astype(int)

    if self.long_term_mortality:
        labeled_cohorts["long_term_mortality"] = (
            died
            & (self.obs_size * 60 + self.gap_size * 60 < dischtime)
            & (dischtime <= self.obs_size * 60 + self.long_term_pred_size * 60)
        ).astype(int)

    if self.los_3day:
        labeled_cohorts["los_3day"] = (labeled_cohorts["LOS"] > 3).astype(int)
    if self.los_7day:
        labeled_cohorts["los_7day"] = (labeled_cohorts["LOS"] > 7).astype(int)

    if self.final_acuity:
        # Deaths after leaving the ICU but before hospital discharge.
        in_hospital_death = ~in_icu_death & hos_death
        final_acuity = (
            labeled_cohorts["HOS_DISCHARGE_LOCATION"]
            .mask(in_icu_death, "IN_ICU_MORTALITY")
            .mask(in_hospital_death, "IN_HOSPITAL_MORTALITY")
        )
        labeled_cohorts["final_acuity"] = encode_classes(
            final_acuity, "final_acuity_classes.tsv"
        )

    if self.imminent_discharge:
        # Discharged within [obs + gap, obs + pred]. Deaths (in ICU or in
        # hospital) are retained as the "Death" class.
        is_discharged = (self.obs_size * 60 + self.gap_size * 60 <= dischtime) & (
            dischtime <= self.obs_size * 60 + self.pred_size * 60
        )
        imminent = (
            labeled_cohorts["HOS_DISCHARGE_LOCATION"]
            .where(is_discharged, "No Discharge")
            .mask(is_discharged & died, "Death")
        )
        labeled_cohorts["imminent_discharge"] = encode_classes(
            imminent, "imminent_discharge_classes.tsv"
        )

    # clean up unnecessary columns
    labeled_cohorts = labeled_cohorts.drop(
        columns=["LOS", "IN_ICU_MORTALITY", "DISCHTIME", "HOS_DISCHARGE_LOCATION"]
    )

    self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")

    logger.info("Done preparing tasks except for diagnosis prediction.")

    return labeled_cohorts
|
|
332
|
+
|
|
333
|
+
def process_tables(self, cohorts, spark):
    """Encode every table in ``self.tables`` into per-event token sequences.

    For each table:

    1. Read the raw file with Spark.
    2. Join it to the cohort, on the ICU-stay key or on the admission key
       when the table has no stay key.
    3. Add a ``TIME`` column (minute offset from ``INTIME``) and keep only
       rows with ``0 <= TIME <= obs_size * 60``.
    4. Turn every surviving row into integer lists ``(INPUTS, TYPES, DPES)``
       according to ``emb_type``:

       * ``"textbase"``: the table name, column names and values are
         tokenized with ``self.tokenizer``. DPES holds per-token ids derived
         from the digit positions of each number (0 for other tokens).
       * ``"codebase"``: numeric columns are bucketized per code and
         categorical values are passed through ``col_name_add``. Tables,
         columns and values are then mapped to ids through
         ``self.table_feature_dict``, which is pickled to
         ``<cache_dir>/<ehr_name>/codebase_code2idx_<feature>.pkl``.

    Args:
        cohorts: pandas DataFrame (converted to Spark here) or a Spark
            DataFrame holding ``hadm_key``, ``icustay_key``, ``INTIME`` and
            ``OUTTIME``.
        spark: active SparkSession.

    Returns:
        Spark DataFrame with columns ``(icustay_key, TIME, INPUTS, TYPES,
        DPES)``: the union over all tables.

    Raises:
        ValueError: ``emb_type`` is neither ``"textbase"`` nor ``"codebase"``.
        AssertionError: a table has neither the ICU-stay nor the admission key.
        NotImplementedError: unsupported ``timeoffsetunit``.
    """
    if self.emb_type not in ("textbase", "codebase"):
        raise ValueError(
            "emb_type must be 'textbase' or 'codebase', got {!r}".format(self.emb_type)
        )

    if isinstance(cohorts, pd.DataFrame):
        cohorts = cohorts[[self.hadm_key, self.icustay_key, "INTIME", "OUTTIME"]]
        logger.info("Start Preprocessing Tables, Cohort Numbers: {}".format(len(cohorts)))
        cohorts = spark.createDataFrame(cohorts)
        print("Converted Cohort to Pyspark DataFrame")
    else:
        logger.info("Start Preprocessing Tables")

    obs_size = self.obs_size
    # Identifier columns; never treated as event features in the codebook.
    id_cols = [self.icustay_key, self.hadm_key, self.patient_key, "TIME"]
    # Presumably the tokenizer ids of "0".."9" (121-130) and "." (119);
    # get_dpe treats 119 as the decimal point. TODO confirm against the vocab.
    number_ids = [121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 119]
    decimal_point_id = 119

    schema = StructType(
        [
            StructField("INPUTS", ArrayType(IntegerType()), False),
            StructField("TYPES", ArrayType(IntegerType()), False),
            StructField("DPES", ArrayType(IntegerType()), False),
        ]
    )

    def minutes_since_intime(ts_col):
        # Difference of both columns cast to long, divided by 60 and rounded
        # (seconds -> minutes for timestamp columns).
        return F.round((F.col(ts_col).cast("long") - F.col("INTIME").cast("long")) / 60)

    def get_dpe(tokens, number_groups):
        """Return DPE ids aligned with ``tokens`` (0 for non-number tokens)."""
        numbers = [i for i, tok in enumerate(tokens) if tok in number_ids]
        numbers_cnt = 0
        data_dpe = [0] * len(tokens)
        for group in number_groups:
            matched = group[0]
            if matched == "." * len(matched):
                numbers_cnt += len(matched)
                continue

            start = numbers[numbers_cnt]
            end = numbers[numbers_cnt + len(matched) - 1] + 1
            corresponding_numbers = tokens[start:end]
            digits = [i for i, tok in enumerate(corresponding_numbers) if tok == decimal_point_id]

            if len(digits) == 0:  # integer
                data_dpe[start:end] = list(range(len(matched) + 5, 5, -1))
            elif len(digits) == 1:  # float
                digit_idx = len(matched) - digits[0]
                data_dpe[start:end] = list(
                    range(len(matched) + 5 - digit_idx, 5 - digit_idx, -1)
                )
            else:
                logger.warning(f"{data_dpe[start:end]} has irregular numerical formats")

            numbers_cnt += len(matched)
        return data_dpe

    def encode_text_unit(text, type_id):
        """Tokenize a table name / column / value; return ``(input_ids, types, dpes)``."""
        text = re.sub(r"\d*\.\d+", lambda x: str(round(float(x.group(0)), 4)), str(text))
        number_groups = list(re.finditer(r"([0-9]+([.][0-9]*)?|[0-9]+|\.+)", text))
        # Split digits and dots into separate tokens.
        text = re.sub(r"([0-9\.])", r" \1 ", text)
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        types = [type_id] * len(input_ids)
        return input_ids, types, get_dpe(input_ids, number_groups)

    def encode_code_unit(feature, type_id):
        """Look ``feature`` up in the codebook; return ``([id], [type_id], [0])``."""
        return [self.table_feature_dict[feature]], [type_id], [0]

    process_unit = encode_code_unit if self.emb_type == "codebase" else encode_text_unit

    def process_row(encoded_table_name, encoded_cols):
        """Build a Spark UDF mapping a row struct to ``(INPUTS, TYPES, DPES)``.

        ``encoded_table_name`` is an ``(inputs, types, dpes)`` tuple;
        ``encoded_cols`` maps column name -> ``(inputs, types, dpes)``.
        """

        def _process_row(row):
            row = row.asDict()
            # Copy the shared prefix: ``+=`` below extends lists in place, so
            # reusing encoded_table_name's lists would grow them on every row.
            input_ids = list(encoded_table_name[0])
            types = list(encoded_table_name[1])
            dpes = list(encoded_table_name[2])
            for col, val in row.items():
                if col in [self.icustay_key, "TIME"] or val is None or val == " ":
                    continue
                if col not in encoded_cols:
                    # codebase only encodes feature columns (see id_cols).
                    continue
                encoded_col = encoded_cols[col]
                encoded_val = process_unit(val, self.value_type_id)
                # +2 reserves room for the CLS / SEP tokens added around each
                # event in make_input; columns that no longer fit are dropped.
                if (
                    len(input_ids) + len(encoded_col[0]) + len(encoded_val[0]) + 2
                    > self.max_event_token_len
                ):
                    break
                input_ids += encoded_col[0] + encoded_val[0]
                types += encoded_col[1] + encoded_val[1]
                dpes += encoded_col[2] + encoded_val[2]
            return input_ids, types, dpes

        return F.udf(_process_row, returnType=schema)

    events_dfs = []
    for table_index, table in enumerate(self.tables):
        fname = table["fname"]
        table_name = fname.split("/")[-1][: -len(self.ext)]
        timestamp_key = table["timestamp"]
        offset_unit = table["timeoffsetunit"]
        logger.info("{} in progress.".format(fname))

        # Optional code -> description lookup, one mapping per code column.
        code_to_descriptions = None
        if "code" in table:
            code_to_descriptions = {}
            for code_key, desc_fname, desc_key in zip(
                table["code"], table["desc"], table["desc_key"]
            ):
                desc_df = pd.read_csv(os.path.join(self.data_dir, desc_fname))
                code_to_descriptions[code_key] = dict(zip(desc_df[code_key], desc_df[desc_key]))

        events = spark.read.csv(os.path.join(self.data_dir, fname), header=True)
        infer_icustay_from_hadm_key = self.icustay_key not in events.columns
        if infer_icustay_from_hadm_key and self.hadm_key not in events.columns:
            raise AssertionError(
                "{} doesn't have one of these columns: {}".format(
                    fname, [self.icustay_key, self.hadm_key]
                )
            )

        events = events.drop(*table["exclude"])
        if offset_unit == "abs":
            events = events.withColumn(timestamp_key, F.to_timestamp(timestamp_key))

        if infer_icustay_from_hadm_key:
            events = events.join(
                cohorts.select(self.hadm_key, self.icustay_key, "INTIME", "OUTTIME"),
                on=self.hadm_key,
                how="inner",
            )
            if offset_unit == "abs":
                # Keep only events between INTIME and OUTTIME of the matched stay.
                events = (
                    events.withColumn("TEMP_TIME", minutes_since_intime(timestamp_key))
                    .filter(F.col("TEMP_TIME") >= 0)
                    .filter(F.col("TEMP_TIME") <= F.col("OUTTIME"))
                    .drop("TEMP_TIME")
                )
            else:
                # All tables in eICU have icustay_key -> no need to handle
                raise NotImplementedError()
            events = events.join(
                cohorts.select(self.icustay_key), on=self.icustay_key, how="leftsemi"
            )
        else:
            events = events.join(
                cohorts.select(self.icustay_key, "INTIME", "OUTTIME"),
                on=self.icustay_key,
                how="inner",
            )

        if offset_unit == "abs":
            events = events.withColumn("TIME", minutes_since_intime(timestamp_key))
        elif offset_unit == "min":
            events = events.withColumn("TIME", F.col(timestamp_key).cast("int"))
        else:
            raise NotImplementedError()
        events = events.drop(timestamp_key)

        # Observation window: [0, obs_size * 60] minutes.
        events = events.filter(F.col("TIME") >= 0).filter(F.col("TIME") <= obs_size * 60)
        events = events.drop("INTIME", "OUTTIME", self.hadm_key)

        if code_to_descriptions:
            for mapped_col, mapping in code_to_descriptions.items():
                mapping_expr = F.create_map([F.lit(x) for x in chain(*mapping.items())])
                events = events.withColumn(mapped_col, mapping_expr[F.col(mapped_col)])

        if self.emb_type == "codebase":
            print("codebase pre-process starts")
            events = events.toPandas()
            print("Converted events Pyspark DataFrame to Pandas DataFrame")

            code_col = table["code_feat"][0]
            dtype_code = events[code_col].dtype

            # Numeric features: dense rank per code, then q_cut into bucket_num
            # buckets (presumably quantiles), stored as "B_<code>_<bucket>".
            for numeric_col in table["numeric_feat"]:
                if numeric_col not in events.columns:
                    continue
                is_numeric = pd.to_numeric(events[numeric_col], errors="coerce").notnull()
                numeric = events[is_numeric].astype({numeric_col: "float"})
                not_numeric = events[~is_numeric]

                numeric.loc[:, numeric_col] = numeric.groupby(code_col)[numeric_col].transform(
                    lambda x: x.rank(method="dense")
                )
                numeric.loc[:, numeric_col] = numeric.groupby(code_col)[numeric_col].transform(
                    lambda x: q_cut(x, self.bucket_num)
                )
                numeric[numeric_col] = (
                    "B_" + numeric[code_col].astype("str") + "_" + numeric[numeric_col].astype("str")
                )

                events = pd.concat([numeric, not_numeric], axis=0)

            # Categorical features: tag each value with its column name.
            for cate_col in table["categorical_feat"]:
                if cate_col in events.columns:
                    events.loc[:, cate_col] = events[cate_col].map(
                        lambda x, c=cate_col: col_name_add(x, c)
                    )

            events = events.astype({code_col: dtype_code})

            feature_cols = [c for c in events.columns if c not in id_cols]
            table_feature_unique = list(
                set(chain.from_iterable(events[c].unique() for c in feature_cols))
            )

            if not events_dfs:
                # First table: start a fresh codebook. Ids 0-2 are padding / CLS /
                # SEP and 3 .. 2 + len(self.tables) are the table-name ids.
                self.table_feature_dict = {}
            # Hand out ids strictly above every id in use, so new values and
            # column names never collide with existing entries. Column ids are
            # stored in the same dict so the pickled codebook matches the
            # encoded rows.
            next_id = max(self.table_feature_dict.values(), default=2 + len(self.tables)) + 1
            for key in chain(table_feature_unique, feature_cols):
                if key not in self.table_feature_dict:
                    self.table_feature_dict[key] = next_id
                    next_id += 1

            encoded_table_name = ([3 + table_index], [self.table_type_id], [0])
            encoded_cols = {
                k: ([self.table_feature_dict[k]], [self.column_type_id], [0]) for k in feature_cols
            }

            print("length of codebook = ", len(self.table_feature_dict))

            events = events.fillna(" ")
            events = spark.createDataFrame(events)
            print("Converted Events DataFrame to Pyspark DataFrame")
        else:
            print("textbase pre-process starts")
            encoded_table_name = process_unit(table_name, self.table_type_id)
            encoded_cols = {k: process_unit(k, self.column_type_id) for k in events.columns}

        # Encode each row, then keep only the stay id, time offset and encoding.
        events = (
            events.withColumn(
                "tmp", process_row(encoded_table_name, encoded_cols)(F.struct(*events.columns))
            )
            .withColumn("INPUTS", F.col("tmp.INPUTS"))
            .withColumn("TYPES", F.col("tmp.TYPES"))
            .withColumn("DPES", F.col("tmp.DPES"))
            .select(self.icustay_key, "TIME", "INPUTS", "TYPES", "DPES")
        )

        events_dfs.append(events)

    if self.emb_type == "codebase":
        os.makedirs(os.path.join(self.cache_dir, self.ehr_name), exist_ok=True)
        with open(
            os.path.join(self.cache_dir, self.ehr_name, f"codebase_code2idx_{self.feature}.pkl"), "wb"
        ) as f:
            pickle.dump(self.table_feature_dict, f)

    return reduce(lambda x, y: x.union(y), events_dfs)
|
|
617
|
+
|
|
618
|
+
def make_input(self, cohorts, events, spark):
    """Encode per-stay event sequences into model inputs and write them to HDF5.

    Events are grouped by ``self.icustay_key`` and sorted by ``TIME``. Each stay
    is truncated from the front (oldest events dropped) so that:

    * the flattened sequence, including one leading CLS token, fits in
      ``self.max_patient_token_len``; and
    * at most ``self.max_event_size`` events remain.

    Stays left with ``<= self.min_event_size`` events are skipped.

    Two encodings are produced per stay:

    * ``hi``: one ``[CLS] tokens [SEP]`` row per event, padded to
      ``max_event_token_len``. Input, type and DPE ids are stacked on axis 1,
      giving shape ``(n_events, 3, max_event_token_len)``, stored as int16.
    * ``fl``: ``[CLS] ev1 [SEP] ev2 [SEP] ...`` padded to
      ``max_patient_token_len``. The three id rows are stacked on axis 0,
      giving shape ``(3, max_patient_token_len)``, stored as int16.

    The Spark UDF cannot write to a shared HDF5 file, so each stay is first
    pickled into a cache directory. The driver then gathers those pickles into
    ``<dest>/data/data.h5`` under ``ehr/<stay_id>`` and removes the cache.
    The surviving cohort's id/task columns are written to
    ``<dest>/data/label.csv`` and copied onto each stay group as HDF5 attributes.

    Args:
        cohorts: Labeled cohort table, either Spark or pandas (Spark frames are
            converted with ``toPandas``).
        events: Spark DataFrame with columns ``self.icustay_key``, ``TIME``,
            ``INPUTS``, ``TYPES`` and ``DPES``.
        spark: Spark session. Not used by this method.
    """
    cache_root = os.path.join(self.cache_dir, self.ehr_name, self.emb_type, self.feature)

    # NOTE(review): PandasUDFType.GROUPED_MAP is deprecated in Spark 3.x in favour of
    # ``groupBy(...).applyInPandas``; kept for compatibility with the existing setup.
    @F.pandas_udf(returnType="TIME int", functionType=F.PandasUDFType.GROUPED_MAP)
    def _make_input(events):
        # The real output is the per-stay pickle written below. Spark still requires
        # a frame matching ``returnType``, so the TIME column is returned.
        df = events.sort_values("TIME")
        flatten_cut_idx = -1
        # Cumulative flattened length; +1 per event accounts for its SEP token.
        flatten_lens = np.cumsum(df["INPUTS"].str.len() + 1).values
        event_length = len(df)

        # One slot is reserved for the CLS token at the start of the flattened input.
        if flatten_lens[-1] > self.max_patient_token_len - 1:
            # Smallest prefix whose removal brings the total within budget; every
            # event up to and including ``flatten_cut_idx`` is dropped.
            flatten_cut_idx = np.searchsorted(
                flatten_lens, flatten_lens[-1] - self.max_patient_token_len + 1
            )
            flatten_lens = (flatten_lens - flatten_lens[flatten_cut_idx])[flatten_cut_idx + 1 :]
            event_length = len(flatten_lens)

        # Event length should not be longer than max_event_size.
        event_length = min(event_length, self.max_event_size)
        # Keep the most recent ``event_length`` rows. ``iloc[-0:]`` would return
        # the whole frame, so slice from an explicit start index instead.
        df = df.iloc[len(df) - event_length :]

        if len(df) <= self.min_event_size:
            return events["TIME"].to_frame()

        def make_hi(cls_id, sep_id, iterable):
            # One [CLS] ... [SEP] row per event.
            return [[cls_id] + list(i) + [sep_id] for i in iterable]

        def make_fl(cls_id, sep_id, iterable):
            # A single [CLS] followed by every event, each terminated by [SEP].
            return [cls_id] + list(chain(*[list(i) + [sep_id] for i in iterable]))

        hi_input = make_hi(self.cls_token_id, self.sep_token_id, df["INPUTS"])
        hi_type = make_hi(self.cls_type_id, self.sep_type_id, df["TYPES"])
        hi_dpe = make_hi(self.others_dpe_id, self.others_dpe_id, df["DPES"])

        fl_input = make_fl(self.cls_token_id, self.sep_token_id, df["INPUTS"])
        fl_type = make_fl(self.cls_type_id, self.sep_type_id, df["TYPES"])
        fl_dpe = make_fl(self.others_dpe_id, self.others_dpe_id, df["DPES"])

        assert len(hi_input) <= self.max_event_size, hi_input
        assert all([len(i) <= self.max_event_token_len for i in hi_input]), hi_input
        assert len(fl_input) <= self.max_patient_token_len, fl_input

        # Add padding to save as numpy array.
        hi_input = np.array(
            [np.pad(i, (0, self.max_event_token_len - len(i)), mode="constant") for i in hi_input]
        )
        hi_type = np.array(
            [np.pad(i, (0, self.max_event_token_len - len(i)), mode="constant") for i in hi_type]
        )
        hi_dpe = np.array(
            [np.pad(i, (0, self.max_event_token_len - len(i)), mode="constant") for i in hi_dpe]
        )

        fl_input = np.pad(fl_input, (0, self.max_patient_token_len - len(fl_input)), mode="constant")
        fl_type = np.pad(fl_type, (0, self.max_patient_token_len - len(fl_type)), mode="constant")
        fl_dpe = np.pad(fl_dpe, (0, self.max_patient_token_len - len(fl_dpe)), mode="constant")

        stay_id = df[self.icustay_key].values[0]
        # Create caches (cannot write to hdf5 directly with pyspark).
        data = {
            "hi": np.stack([hi_input, hi_type, hi_dpe], axis=1).astype(np.int16),
            "fl": np.stack([fl_input, fl_type, fl_dpe], axis=0).astype(np.int16),
            "time": df["TIME"].values,
        }
        with open(os.path.join(cache_root, f"{stay_id}.pkl"), "wb") as fptr:
            pickle.dump(data, fptr)
        return events["TIME"].to_frame()

    os.makedirs(cache_root, exist_ok=True)

    # "noop" sink: only forces execution of the UDF for its side effects.
    events.groupBy(self.icustay_key).apply(_make_input).write.mode("overwrite").format("noop").save()

    logger.info("Finish Data Preprocessing. Start to write to hdf5")

    active_stay_ids = []

    # A dedicated, context-managed handle guarantees the HDF5 file is closed and
    # flushed. Previously ``f`` was rebound to each pickle handle in the loop, so
    # the trailing ``f.close()`` never closed the HDF5 file.
    with h5py.File(os.path.join(self.dest, "data/data.h5"), "w") as h5_file:
        ehr_g = h5_file.create_group("ehr")

        for stay_id_file in tqdm(os.listdir(cache_root)):
            stay_id = stay_id_file.split(".")[0]
            with open(os.path.join(cache_root, stay_id_file), "rb") as fptr:
                data = pickle.load(fptr)
            stay_g = ehr_g.create_group(str(stay_id))
            stay_g.create_dataset("hi", data=data["hi"], dtype="i2", compression="lzf", shuffle=True)
            stay_g.create_dataset("fl", data=data["fl"], dtype="i2", compression="lzf", shuffle=True)
            stay_g.create_dataset("time", data=data["time"], dtype="i")
            active_stay_ids.append(int(stay_id))

        shutil.rmtree(cache_root, ignore_errors=True)

        # Drop patients with few events (stays that produced no cache file).
        if not isinstance(cohorts, pd.DataFrame):
            cohorts = cohorts.toPandas()

        logger.info(
            "Total {} patients in the cohort are skipped due to few events".format(
                len(cohorts) - len(active_stay_ids)
            )
        )
        cohorts = cohorts[cohorts[self.icustay_key].isin(active_stay_ids)]

        # TODO: splits should be made per patient (self.patient_key) rather than per stay.

        # Keep the stay id plus whichever task label columns are present.
        task_columns = [
            "mortality",
            "long_term_mortality",
            "los_3day",
            "los_7day",
            "readmission",
            "final_acuity",
            "imminent_discharge",
        ]
        selected_columns = [self.icustay_key] + [col for col in task_columns if col in cohorts.columns]
        # Chained rename avoids an in-place modification of a slice of ``cohorts``.
        cohorts = cohorts[selected_columns].rename(columns={self.icustay_key: "stay_id"})

        cohorts.to_csv(os.path.join(self.dest, "data/label.csv"), index=False)

        # Record cohort labels as attributes of each stay group. ``to_dict("records")``
        # keeps per-column dtypes; ``iterrows`` would upcast ``stay_id`` to float
        # whenever a label column is float, producing a key such as "123.0".
        for record in cohorts.to_dict("records"):
            group = ehr_g[str(record["stay_id"])]
            for col, value in record.items():
                if isinstance(value, (pd.Timestamp, pd.Timedelta)):
                    continue
                group.attrs[col] = value

    logger.info("Done encoding events.")

    return
|
|
777
|
+
|
|
778
|
+
def run_pipeline(self, spark) -> None:
    """Run the preprocessing stages in order.

    The stages are: build cohorts, attach task labels, process event tables,
    and finally encode model inputs. ``self.cache`` is forwarded as ``cached``
    to the cohort and task stages.
    """
    base_cohorts = self.build_cohorts(cached=self.cache)
    with_labels = self.prepare_tasks(base_cohorts, spark, cached=self.cache)
    encoded_events = self.process_tables(with_labels, spark)
    self.make_input(with_labels, encoded_events, spark)
|
|
783
|
+
|
|
784
|
+
def add_special_tokens(self, new_special_tokens: Union[str, List]) -> None:
    """Register new special tokens, mapping each one to an ``[unusedN]`` placeholder.

    Slots are numbered consecutively from ``self.num_special_tokens + 1``.

    Args:
        new_special_tokens: A single token or a list of tokens. Duplicates
            within the request are collapsed (first occurrence wins), so each
            token consumes exactly one slot.

    Returns:
        None. The request is aborted with a logged warning, and
        ``self.special_tokens_dict`` is left unchanged, when any token is
        already registered or when the total would exceed
        ``self.max_special_tokens``.
    """
    if isinstance(new_special_tokens, str):
        new_special_tokens = [new_special_tokens]
    # Order-preserving de-duplication. Repeated tokens would otherwise burn several
    # [unusedN] slots while only the last assignment survived in the dict.
    new_special_tokens = list(dict.fromkeys(new_special_tokens))

    num_special_tokens = self.num_special_tokens
    overlapped = [token for token in new_special_tokens if token in self.special_tokens_dict]

    if overlapped:
        logger.warning(
            "There are some tokens that have already been set to special tokens."
            " Please provide only NEW tokens. Aborted."
        )
        return None
    if num_special_tokens + len(new_special_tokens) > self.max_special_tokens:
        logger.warning(
            f"Total additional special tokens should not exceed {self.max_special_tokens}." " Aborted."
        )
        return None

    self.special_tokens_dict.update(
        {
            k: "[unused{}]".format(i)
            for i, k in enumerate(new_special_tokens, start=num_special_tokens + 1)
        }
    )
|
|
812
|
+
|
|
813
|
+
def make_compatible(self, icustays):
    """Harmonise a dataset-specific ICU-stay table into the shared schema.

    Intended to be overridden per EHR source. Implementations convert
    outtime/dischtime into minutes relative to intime, while leaving intime
    itself at its original value for later use.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
|
|
820
|
+
|
|
821
|
+
def is_compatible(self, icustays):
    """Return True when ``icustays`` has every column of the shared schema.

    The required columns are the three id keys (``hadm_key``, ``icustay_key``,
    ``patient_key``) plus LOS, AGE, INTIME, OUTTIME, DISCHTIME,
    IN_ICU_MORTALITY and HOS_DISCHARGE_LOCATION.
    """
    required_columns = (
        self.hadm_key,
        self.icustay_key,
        self.patient_key,
        "LOS",
        "AGE",
        "INTIME",
        "OUTTIME",
        "DISCHTIME",
        "IN_ICU_MORTALITY",
        "HOS_DISCHARGE_LOCATION",
    )
    present = icustays.columns.to_list()
    return all(column in present for column in required_columns)
|
|
838
|
+
|
|
839
|
+
def save_to_cache(self, f, fname, use_pickle=False) -> None:
    """Serialise ``f`` to ``<self.cache_dir>/<fname>``.

    Args:
        f: Object to store. Without ``use_pickle`` it must provide
            ``to_pickle`` (e.g. a pandas DataFrame).
        fname: File name relative to ``self.cache_dir``.
        use_pickle: Write with the stdlib ``pickle`` module instead of
            ``f.to_pickle``.
    """
    target_path = os.path.join(self.cache_dir, fname)
    if not use_pickle:
        f.to_pickle(target_path)
        return

    import pickle

    with open(target_path, "wb") as handle:
        pickle.dump(f, handle)
|
|
847
|
+
|
|
848
|
+
def load_from_cache(self, fname):
    """Load ``<self.cache_dir>/<fname>`` with ``pd.read_pickle``.

    Returns:
        The unpickled object, or None when the file does not exist.

    NOTE(review): unpickling executes arbitrary code; only use on trusted cache files.
    """
    cache_path = os.path.join(self.cache_dir, fname)
    if not os.path.exists(cache_path):
        return None

    restored = pd.read_pickle(cache_path)
    logger.info("Loaded data from {}".format(cache_path))
    return restored
|
|
857
|
+
|
|
858
|
+
def infer_data_extension(self) -> str:
    """Return the file extension of the raw EHR tables (subclass hook).

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
|
|
860
|
+
|
|
861
|
+
def download_ehr_from_url(self, url, dest) -> None:
    """Recursively download a credentialed EHR dataset with ``wget``.

    The username is read from stdin; ``--ask-password`` makes wget prompt for
    the password itself. Files land in ``dest`` under a directory named after
    ``url`` with its ``http(s)://`` prefix removed.

    Args:
        url: Dataset URL to mirror.
        dest: Local directory passed to ``wget -P``.

    Raises:
        AssertionError: if the expected output directory is missing afterwards.
    """
    username = input("Email or Username: ")
    subprocess.run(
        [
            "wget",
            "-r",
            "-N",
            "-c",
            # ``-np`` (--no-parent): never ascend above ``url`` while recursing.
            # Previously passed as a bare "np", which wget treated as another URL.
            "-np",
            "--user",
            username,
            "--ask-password",
            url,
            "-P",
            dest,
        ]
    )
    output_dir = url.replace("https://", "").replace("http://", "")

    if not os.path.exists(os.path.join(dest, output_dir)):
        raise AssertionError(
            "Download failed. Please check your network connection or "
            "if you log in with a credentialed user"
        )
|
|
885
|
+
|
|
886
|
+
def download_ccs_from_url(self, dest) -> None:
    """Fetch the 2015 multi-level CCS archive and keep only its diagnosis CSV.

    The zip is downloaded into ``dest`` with ``wget`` and extracted into a
    scratch ``foo.d`` folder. ``ccs_multi_dx_tool_2015.csv`` is moved up into
    ``dest``, and the archive and scratch folder are then deleted.
    """
    import zipfile

    subprocess.run(
        [
            "wget",
            "-N",
            "-c",
            "https://www.hcup-us.ahrq.gov/toolssoftware/ccs/Multi_Level_CCS_2015.zip",
            "-P",
            dest,
        ]
    )

    archive_path = os.path.join(dest, "Multi_Level_CCS_2015.zip")
    scratch_dir = os.path.join(dest, "foo.d")

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(scratch_dir)
    os.rename(
        os.path.join(scratch_dir, "ccs_multi_dx_tool_2015.csv"),
        os.path.join(dest, "ccs_multi_dx_tool_2015.csv"),
    )
    os.remove(archive_path)
    shutil.rmtree(scratch_dir)
|
|
908
|
+
|
|
909
|
+
def download_icdgem_from_url(self, dest) -> None:
    """Download ``icd10cmtoicd9gem.csv`` from data.nber.org into ``dest`` via ``wget``."""
    wget_command = [
        "wget",
        "-N",
        "-c",
        "https://data.nber.org/gem/icd10cmtoicd9gem.csv",
        "-P",
        dest,
    ]
    subprocess.run(wget_command)
|