genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,550 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from collections import Counter
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pyspark.sql.functions as F
|
|
8
|
+
import treelib
|
|
9
|
+
from ehrs import EHR, register_ehr
|
|
10
|
+
|
|
11
|
+
# Module-level logger used by the eICU preprocessing code in this file.
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@register_ehr("eicu")
class eICU(EHR):
    """Preprocessing for the eICU Collaborative Research Database (v2.0).

    Declares which eICU tables/columns feed the model, builds ICU-stay cohorts
    from the ``patient`` table, and labels them for diagnosis prediction and the
    optional lab-based clinical tasks (bilirubin / platelets / creatinine / wbc).
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        self.ehr_name = "eicu"

        if self.data_dir is None:
            self.data_dir = os.path.join(self.cache_dir, self.ehr_name)

            if not os.path.exists(self.data_dir):
                logger.info(
                    "Data is not found so try to download from the internet. "
                    "Note that this is a restricted-access resource. "
                    "Please log in to physionet.org with a credentialed user."
                )
                self.download_ehr_from_url(
                    url="https://physionet.org/files/eicu-crd/2.0/", dest=self.data_dir
                )

        logger.info("Data directory is set to {}".format(self.data_dir))

        if self.ccs_path is None:
            self.ccs_path = os.path.join(self.cache_dir, "ccs_multi_dx_tool_2015.csv")

            if not os.path.exists(self.ccs_path):
                logger.info("`ccs_multi_dx_tool_2015.csv` is not found so try to download from the internet.")
                self.download_ccs_from_url(self.cache_dir)

        if self.gem_path is None:
            self.gem_path = os.path.join(self.cache_dir, "icd10cmtoicd9gem.csv")

            if not os.path.exists(self.gem_path):
                logger.info("`icd10cmtoicd9gem.csv` is not found so try to download from the internet.")
                self.download_icdgem_from_url(self.cache_dir)

        if self.ext is None:
            self.ext = self.infer_data_extension()

        self._icustay_fname = "patient" + self.ext
        self._diagnosis_fname = "diagnosis" + self.ext

        # Event tables used as model input: file name, time column (an offset
        # expressed in `timeoffsetunit`) and columns that are always dropped.
        self.tables = [
            {
                "fname": "lab" + self.ext,
                "timestamp": "labresultoffset",
                "timeoffsetunit": "min",
                "exclude": ["labid", "labresultrevisedoffset"],
            },
            {
                "fname": "medication" + self.ext,
                "timestamp": "drugstartoffset",
                "timeoffsetunit": "min",
                "exclude": [
                    "drugorderoffset",
                    "drugstopoffset",
                    "medicationid",
                    "gtc",
                    "drughiclseqno",
                    "drugordercancelled",
                ],
            },
            {
                "fname": "infusionDrug" + self.ext,
                "timestamp": "infusionoffset",
                "timeoffsetunit": "min",
                "exclude": ["infusiondrugid"],
            },
        ]

        if self.feature == "select":
            # Additional columns dropped when only the selected feature set is used.
            extra_exclude_feature_dict = {
                "lab" + self.ext: [
                    "labtypeid",
                    "labresulttext",
                    "labmeasurenameinterface",
                    "labresultrevisedoffset",
                ],
                # The comma after "drugivadmixture" is required: without it Python
                # concatenates the two literals into one non-existent column name.
                "medication" + self.ext: ["drugivadmixture", "frequency", "loadingdose", "prn"],
                "infusionDrug" + self.ext: ["patientweight", "volumeoffluid", "drugrate", "drugamount"],
            }

            for table in self.tables:
                table["exclude"].extend(extra_exclude_feature_dict.get(table["fname"], []))

        if self.emb_type == "codebase":
            # Per-table split of columns into numeric / categorical / code features.
            feature_types_for_codebase_emb_dict = {
                "lab" + self.ext: {
                    "numeric_feat": ["labresult", "labresulttext"],
                    "categorical_feat": ["labtypeid"],
                    "code_feat": ["labname"],
                },
                "medication" + self.ext: {
                    "numeric_feat": ["dosage"],
                    "categorical_feat": ["drugordercancelled", "drugivadmixture"],
                    "code_feat": ["drugname"],
                },
                "infusionDrug" + self.ext: {
                    "numeric_feat": [
                        "drugrate",
                        "infusionrate",
                        "drugamount",
                        "volumeoffluid",
                        "patientweight",
                    ],
                    "categorical_feat": [],
                    "code_feat": ["drugname"],
                },
            }

            for table in self.tables:
                table.update(feature_types_for_codebase_emb_dict.get(table["fname"], {}))

        if self.creatinine or self.bilirubin or self.platelets or self.wbc:
            # Every lab-based task reads the same lab table; only the `labname`
            # value that selects the measurement differs.
            lab_task_labnames = {
                "creatinine": "creatinine",
                "bilirubin": "total bilirubin",
                "platelets": "platelets x 1000",
                "wbc": "WBC x 1000",
            }
            self.task_itemids = {
                task: {
                    "fname": "lab" + self.ext,
                    "timestamp": "labresultoffset",
                    "timeoffsetunit": "min",
                    "exclude": [
                        "labtypeid",
                        "labresulttext",
                        "labmeasurenamesystem",
                        "labmeasurenameinterface",
                        "labresultrevisedoffset",
                    ],
                    "code": ["labname"],
                    "value": ["labresult"],
                    "itemid": [labname],
                }
                for task, labname in lab_task_labnames.items()
            }
            # Intake/output records, used by the creatinine task to detect dialysis.
            self.task_itemids["dialysis"] = {
                "fname": "intakeOutput" + self.ext,
                "timestamp": "intakeoutputoffset",
                "timeoffsetunit": "min",
                "exclude": [
                    "intakeoutputid",
                    "intaketotal",
                    "outputtotal",
                    "nettotal",
                    "intakeoutputentryoffset",
                ],
                "code": ["dialysistotal"],
                "value": [],
                "itemid": [],
            }

        # Raw `hospitaldischargelocation` values -> coarser discharge classes.
        self.disch_map_dict = {
            "Home": "Home",
            "IN_ICU_MORTALITY": "IN_ICU_MORTALITY",
            "Nursing Home": "Other",
            "Other": "Other",
            "Other External": "Other",
            "Other Hospital": "Other",
            "Rehabilitation": "Rehabilitation",
            "Skilled Nursing Facility": "Skilled Nursing Facility",
            "Death": "Death",
        }

        self._icustay_key = "patientunitstayid"
        self._hadm_key = "patienthealthsystemstayid"
        self._patient_key = "uniquepid"

        self._determine_first_icu = "unitvisitnumber"

    def build_cohorts(self, cached=False):
        """Read the ``patient`` table, normalise it and delegate cohort building to the base class."""
        icustays = pd.read_csv(os.path.join(self.data_dir, self.icustay_fname))

        icustays = self.make_compatible(icustays)
        self.icustays = icustays

        cohorts = super().build_cohorts(icustays, cached=cached)

        return cohorts

    def prepare_tasks(self, cohorts, spark, cached=False):
        """Label ``cohorts`` with the base-class tasks plus the eICU-specific ones.

        Adds a ``diagnosis`` column when ``self.diagnosis`` is set. The column holds a
        list of CCS level-1 category ids per hospital stay, with class 14 removed and
        the higher classes shifted down by one. When any clinical task flag is set,
        the result is converted to a Spark DataFrame and gets one label column per
        task.

        When ``cohorts`` is None and ``cached`` is True, the cached labeled
        cohorts are returned directly if available.
        """
        if cohorts is None and cached:
            labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
            if labeled_cohorts is not None:
                # NOTE(review): the cache is written before the clinical-task columns
                # are added below, so a cache hit returns cohorts without them.
                return labeled_cohorts

        labeled_cohorts = super().prepare_tasks(cohorts, spark, cached)

        if self.diagnosis:
            logger.info("Start labeling cohorts for diagnosis prediction.")

            str2cat = self.make_dx_mapping()
            dx = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))
            dx = dx.merge(cohorts[[self.icustay_key, self.hadm_key]], on=self.icustay_key)
            dx["diagnosis"] = dx["diagnosisstring"].map(lambda x: str2cat.get(x, -1))
            # Drop unmapped strings (-1) and the rare class 14, then shift the classes
            # above 14 down by one so the label ids stay contiguous. `.copy()` makes
            # the in-place shift act on an independent frame, not a slice.
            dx = dx[(dx["diagnosis"] != -1) & (dx["diagnosis"] != 14)].copy()
            dx.loc[dx["diagnosis"] > 14, "diagnosis"] -= 1
            dx = dx.groupby(self.hadm_key)["diagnosis"].agg(lambda x: sorted(set(x))).to_frame()

            labeled_cohorts = labeled_cohorts.merge(dx, on=self.hadm_key, how="left")
            # Stays without any mapped diagnosis come back from the merge as NaN -> [].
            labeled_cohorts["diagnosis"] = labeled_cohorts["diagnosis"].apply(
                lambda x: x if isinstance(x, list) else []
            )

            logger.info("Done preparing diagnosis prediction for the given cohorts")

        self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")

        if self.bilirubin or self.platelets or self.creatinine or self.wbc:
            logger.info("Start labeling cohorts for clinical task prediction.")

            labeled_cohorts = spark.createDataFrame(labeled_cohorts)

            for task in ("bilirubin", "platelets", "creatinine", "wbc"):
                if getattr(self, task):
                    labeled_cohorts = self.clinical_task(labeled_cohorts, task, spark)

            logger.info("Done preparing clinical task prediction for the given cohorts")

        return labeled_cohorts

    def make_compatible(self, icustays):
        """Rename and derive ``patient`` columns so they match the other EHR pipelines.

        Adds ``LOS`` (``unitdischargeoffset`` / 60 / 24) and an integer ``AGE``
        (``"> 89"`` is encoded as 300). Rows with a missing age are dropped.
        ``INTIME`` is fixed at 0 and ``OUTTIME`` / ``DISCHTIME`` are the unit and
        hospital discharge offsets. ``IN_ICU_MORTALITY`` is set when
        ``unitdischargestatus`` is ``"Expired"``, and ``HOS_DISCHARGE_LOCATION``
        holds the discharge location mapped through ``disch_map_dict`` (unknown
        values become NaN).
        """
        icustays.loc[:, "LOS"] = icustays["unitdischargeoffset"] / 60 / 24
        icustays.dropna(subset=["age"], inplace=True)
        icustays["AGE"] = icustays["age"].replace("> 89", 300).astype(int)

        # hacks for compatibility with other ehrs: eICU only stores offsets, so the
        # unit admission is time 0.
        icustays["INTIME"] = 0
        icustays.rename(columns={"unitdischargeoffset": "OUTTIME"}, inplace=True)
        # NOTE: no DEATHTIME column is derived for eICU.
        icustays.rename(columns={"hospitaldischargeoffset": "DISCHTIME"}, inplace=True)

        icustays["IN_ICU_MORTALITY"] = icustays["unitdischargestatus"] == "Expired"
        icustays["hospitaldischargelocation"] = icustays["hospitaldischargelocation"].map(self.disch_map_dict)
        icustays.rename(columns={"hospitaldischargelocation": "HOS_DISCHARGE_LOCATION"}, inplace=True)

        return icustays

    def make_dx_mapping(self):
        """Map each eICU ``diagnosisstring`` to a 0-based CCS level-1 category.

        1. ``diagnosisstring`` -> codes, taken from the ``icd9code`` column, which can
           hold several comma-separated ICD-9 and/or ICD-10 codes.
        2. Codes that are not CCS ICD-9 codes are converted to ICD-9 with the CMS
           GEM table. Codes missing from GEM fall back to the most common ICD-9
           target among GEM codes that start with the longest prefix that is itself
           a GEM code.
        3. Codes -> category. A direct ICD-9 match takes precedence over a
           GEM-converted one, and codes that still have no CCS category are skipped.
        4. Strings with no category yet inherit one from the ``|``-separated
           diagnosisstring hierarchy, using a majority vote over known children.

        Returns:
            dict mapping diagnosisstring -> int category. Strings that cannot be
            placed are absent; hierarchy-derived values can be -1 when no category
            is known for any related node.
        """
        diagnosis = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))
        ccs_dx = pd.read_csv(self.ccs_path)
        # dtype=str keeps ICD-9 codes such as "0010" from being parsed as numbers.
        gem = pd.read_csv(self.gem_path, dtype=str)

        diagnosis = diagnosis[["diagnosisstring", "icd9code"]]

        # 1. diagnosisstring -> code(s): one row per (string, code) pair.
        str2code = diagnosis.dropna(subset=["icd9code"])
        str2code = str2code.groupby("diagnosisstring").first().reset_index()
        str2code["icd9code"] = str2code["icd9code"].str.split(",")
        str2code = str2code.explode("icd9code")
        # Strip the blank that follows each comma before removing the dots, so
        # ICD-10 codes can match the GEM table.
        str2code["icd9code"] = str2code["icd9code"].str.strip().str.replace(".", "", regex=False)

        # CCS cells are quoted like "'0010 '"; strip the quotes/blanks and make
        # the level-1 category 0-based.
        ccs_dx["'ICD-9-CM CODE'"] = ccs_dx["'ICD-9-CM CODE'"].str[1:-1].str.strip()
        ccs_dx["'CCS LVL 1'"] = ccs_dx["'CCS LVL 1'"].str[1:-1].astype(int) - 1
        icd2cat = dict(zip(ccs_dx["'ICD-9-CM CODE'"], ccs_dx["'CCS LVL 1'"]))

        # 2. Convert the codes that are NOT ICD-9 (i.e. ICD-10) to ICD-9.
        non_icd9_codes = set(str2code.loc[~str2code["icd9code"].isin(set(icd2cat)), "icd9code"])
        gem_codes = set(gem["icd10cm"].dropna())

        map_cms = dict(zip(gem["icd10cm"], gem["icd9cm"]))
        map_manual = dict.fromkeys(non_icd9_codes - gem_codes)

        for code_10 in map_manual:
            for i in range(len(code_10), 0, -1):
                tgt_10 = code_10[:i]
                # Membership must be tested on the code values (a set), not on
                # the Series itself, which would test the index.
                if tgt_10 in gem_codes:
                    targets = gem.loc[gem["icd10cm"].str.startswith(tgt_10, na=False), "icd9cm"].mode()
                    if not targets.empty:
                        map_manual[code_10] = targets.iloc[0]
                    break
        icd102icd9 = {**map_cms, **map_manual}

        # 3. Convert the available strings to a category.
        str2cat = {}
        converted = {}
        for dx_str, code in zip(str2code["diagnosisstring"], str2code["icd9code"]):
            if code in icd2cat:
                target, cat = str2cat, icd2cat[code]
            else:
                # GEM targets may be missing or absent from CCS (e.g. "NoDx"); skip them.
                cat = icd2cat.get(icd102icd9.get(code))
                if cat is None:
                    continue
                target = converted
            if dx_str in target and target[dx_str] != cat:
                logger.warning(f"{dx_str} has multiple categories{cat, target[dx_str]}")
            target[dx_str] = cat
        for dx_str, cat in converted.items():
            str2cat.setdefault(dx_str, cat)

        # 4. For strings without a category (~25%), use the diagnosisstring hierarchy.
        tree = treelib.Tree()
        tree.create_node("root", "root")
        for dx_str, cat in str2cat.items():
            parts = dx_str.split("|")
            parent = "root"
            for depth in range(1, len(parts) + 1):
                nid = "|".join(parts[:depth])
                if not tree.contains(nid):
                    tree.create_node(-1, nid, parent=parent)
                parent = nid
            # A string may already exist as an intermediate node of a longer
            # string; give it its own category instead of leaving it unknown.
            node = tree.get_node(parent)
            if node.tag == -1:
                node.tag = cat

        # Fill unknown inner nodes bottom-up with the majority tag of their
        # children. Reversed pre-order visits every child before its parent.
        for nid in reversed(list(tree.expand_tree(mode=treelib.Tree.DEPTH))):
            node = tree.get_node(nid)
            if node.is_leaf() or node.tag != -1:
                continue
            node.tag = Counter(child.tag for child in tree.children(nid)).most_common(1)[0][0]

        unmatched_dxs = set(diagnosis["diagnosisstring"].dropna()) - set(str2cat)
        for dx_str in unmatched_dxs:
            parts = dx_str.split("|")
            # Walk up from the direct parent but never to the top level (adds noise).
            for i in range(len(parts) - 1, 1, -1):
                prefix = "|".join(parts[:i])
                if tree.contains(prefix):
                    str2cat[dx_str] = tree.get_node(prefix).tag
                    break

        return str2cat

    def clinical_task(self, cohorts, task, spark):
        """Add a ``task`` label column to the Spark DataFrame ``cohorts``.

        The label is derived from the mean lab value of each ICU stay within the
        offset window ``[(obs_size + gap_size) * 60, (obs_size + pred_size) * 60]``.
        The table's offsets are in minutes, so the sizes are presumably hours.
        For ``creatinine``, lab events at or after a stay's first dialysis record
        are ignored. Patients who have more than one stay and any dialysis record
        are dropped from the label. Stays without a label get null after the left
        join.
        """
        spec = self.task_itemids[task]
        timestamp = spec["timestamp"]
        code = spec["code"][0]
        value = spec["value"][0]
        itemid = spec["itemid"]

        # Without a schema, spark.read.csv yields string columns. Cast offsets and
        # values to double so comparisons and min/mean are numeric, not lexicographic.
        table = spark.read.csv(os.path.join(self.data_dir, spec["fname"]), header=True)
        table = table.drop(*spec["exclude"])
        table = (
            table.withColumn(timestamp, F.col(timestamp).cast("double"))
            .withColumn(value, F.col(value).cast("double"))
            .filter(F.col(code).isin(itemid))
            .filter(F.col(value).isNotNull())
        )

        merge = cohorts.join(table, on=self.icustay_key, how="inner")

        if task == "creatinine":
            patient = spark.read.csv(os.path.join(self.data_dir, self._icustay_fname), header=True)
            patient = patient.select(self.patient_key, self.icustay_key, self._hadm_key)

            # Patients with more than one row (stay) in the patient table.
            # NOTE(review): counts rows, not distinct hospital stays.
            multi_hosp = (
                patient.groupBy(self.patient_key)
                .agg(F.count(self._hadm_key).alias("count"))
                .filter(F.col("count") > 1)
                .select(self.patient_key)
                .rdd.flatMap(lambda row: row)
                .collect()
            )

            dialysis_spec = self.task_itemids["dialysis"]  # Only treatment for dialysis
            dialysis_code = dialysis_spec["code"][0]

            io = spark.read.csv(os.path.join(self.data_dir, dialysis_spec["fname"]), header=True)
            io = io.drop(*dialysis_spec["exclude"])

            # Any non-zero dialysis amount counts as a dialysis record. Cast to double
            # explicitly so fractional amounts are not truncated by an implicit cast.
            io_dialysis = io.filter(F.col(dialysis_code).cast("double") != 0)
            io_dialysis = io_dialysis.join(patient, on=self.icustay_key, how="left")

            dialysis_multihosp = (
                io_dialysis.filter(F.col(self.patient_key).isin(multi_hosp))
                .select(self.patient_key)
                .distinct()
                .rdd.flatMap(lambda row: row)
                .collect()
            )

            # Earliest dialysis offset per ICU stay. Only lab events strictly
            # before it are kept for the creatinine label.
            first_dialysis = io_dialysis.groupBy(self.icustay_key).agg(
                F.min(F.col(dialysis_spec["timestamp"]).cast("double")).alias("_DIALYSIS_TIME")
            )
            merge = merge.join(first_dialysis, on=self.icustay_key, how="left")
            merge = merge.filter(F.isnull("_DIALYSIS_TIME") | (F.col("_DIALYSIS_TIME") > F.col(timestamp)))
            merge = merge.drop("_DIALYSIS_TIME")

        # Keep events within (obs_size + gap_size) - (obs_size + pred_size).
        merge = merge.filter(
            (F.col(timestamp) >= (self.obs_size + self.gap_size) * 60)
            & (F.col(timestamp) <= (self.obs_size + self.pred_size) * 60)
        )

        # Average value of events
        value_agg = merge.groupBy(self.icustay_key).agg(F.mean(value).alias("avg_value"))  # TODO: mean/min/max?
        avg = value_agg.avg_value

        # Labeling
        if task == "bilirubin":
            value_agg = value_agg.withColumn(
                task,
                F.when(avg < 1.2, 0)
                .when((avg >= 1.2) & (avg < 2.0), 1)
                .when((avg >= 2.0) & (avg < 6.0), 2)
                .when((avg >= 6.0) & (avg < 12.0), 3)
                .when(avg >= 12.0, 4),
            )
        elif task == "platelets":
            value_agg = value_agg.withColumn(
                task,
                F.when(avg >= 150, 0)
                .when((avg >= 100) & (avg < 150), 1)
                .when((avg >= 50) & (avg < 100), 2)
                .when((avg >= 20) & (avg < 50), 3)
                .when(avg < 20, 4),
            )
        elif task == "creatinine":
            value_agg = value_agg.join(
                patient.select(self.patient_key, self.icustay_key), on=self.icustay_key, how="left"
            )
            avg = value_agg.avg_value
            value_agg = value_agg.withColumn(
                task,
                F.when(avg < 1.2, 0)
                .when((avg >= 1.2) & (avg < 2.0), 1)
                .when((avg >= 2.0) & (avg < 3.5), 2)
                .when((avg >= 3.5) & (avg < 5), 3)
                .when(avg >= 5, 4),
            )
            value_agg = value_agg.filter(~F.col(self.patient_key).isin(dialysis_multihosp))
            value_agg = value_agg.drop(self.patient_key)
        elif task == "wbc":
            value_agg = value_agg.withColumn(
                task,
                F.when(avg < 4, 0).when((avg >= 4) & (avg <= 12), 1).when(avg > 12, 2),
            )

        cohorts = cohorts.join(value_agg.select(self.icustay_key, task), on=self.icustay_key, how="left")

        return cohorts

    def infer_data_extension(self) -> str:
        """Detect whether the eICU tables in ``data_dir`` are ``.csv.gz`` or ``.csv``.

        An extension is accepted when exactly 31 files with it exist in
        ``data_dir``. ``.csv.gz`` is checked first.

        Raises:
            AssertionError: if neither extension matches exactly 31 files.
        """
        for ext in (".csv.gz", ".csv"):
            if len(glob.glob(os.path.join(self.data_dir, "*" + ext))) == 31:
                break
        else:
            raise AssertionError(
                "Provided data directory is not correct. Please check if --data is correct. "
                "--data: {}".format(self.data_dir)
            )

        logger.info("Data extension is set to '{}'".format(ext))

        return ext
|