genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/preprocess/genhpf/ehrs/eicu.py
@@ -0,0 +1,550 @@
+ import glob
+ import logging
+ import os
+ from collections import Counter
+
+ import pandas as pd
+ import pyspark.sql.functions as F
+ import treelib
+ from ehrs import EHR, register_ehr
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_ehr("eicu")
+ class eICU(EHR):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+         self.ehr_name = "eicu"
+
+         if self.data_dir is None:
+             self.data_dir = os.path.join(self.cache_dir, self.ehr_name)
+
+             if not os.path.exists(self.data_dir):
+                 logger.info(
+                     "Data was not found, so it will be downloaded from the internet. "
+                     "Note that this is a restricted-access resource. "
+                     "Please log in to physionet.org with a credentialed user."
+                 )
+                 self.download_ehr_from_url(
+                     url="https://physionet.org/files/eicu-crd/2.0/", dest=self.data_dir
+                 )
+
+         logger.info("Data directory is set to {}".format(self.data_dir))
+
+         if self.ccs_path is None:
+             self.ccs_path = os.path.join(self.cache_dir, "ccs_multi_dx_tool_2015.csv")
+
+             if not os.path.exists(self.ccs_path):
+                 logger.info("`ccs_multi_dx_tool_2015.csv` was not found, so it will be downloaded from the internet.")
+                 self.download_ccs_from_url(self.cache_dir)
+
+         if self.gem_path is None:
+             self.gem_path = os.path.join(self.cache_dir, "icd10cmtoicd9gem.csv")
+
+             if not os.path.exists(self.gem_path):
+                 logger.info("`icd10cmtoicd9gem.csv` was not found, so it will be downloaded from the internet.")
+                 self.download_icdgem_from_url(self.cache_dir)
+
+         if self.ext is None:
+             self.ext = self.infer_data_extension()
+
+         self._icustay_fname = "patient" + self.ext
+         self._diagnosis_fname = "diagnosis" + self.ext
+
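+         # Each table config below names the source file, the event-time column and
+         # its offset unit (eICU records times as minute offsets from ICU admission),
+         # and the columns to exclude from the flattened event features.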
+         self.tables = [
+             {
+                 "fname": "lab" + self.ext,
+                 "timestamp": "labresultoffset",
+                 "timeoffsetunit": "min",
+                 "exclude": ["labid", "labresultrevisedoffset"],
+             },
+             {
+                 "fname": "medication" + self.ext,
+                 "timestamp": "drugstartoffset",
+                 "timeoffsetunit": "min",
+                 "exclude": [
+                     "drugorderoffset",
+                     "drugstopoffset",
+                     "medicationid",
+                     "gtc",
+                     "drughiclseqno",
+                     "drugordercancelled",
+                 ],
+             },
+             {
+                 "fname": "infusionDrug" + self.ext,
+                 "timestamp": "infusionoffset",
+                 "timeoffsetunit": "min",
+                 "exclude": ["infusiondrugid"],
+             },
+         ]
+
+         if self.feature == "select":
+             extra_exclude_feature_dict = {
+                 "lab"
+                 + self.ext: [
+                     "labtypeid",
+                     "labresulttext",
+                     "labmeasurenameinterface",
+                     "labresultrevisedoffset",
+                 ],
+                 "medication" + self.ext: ["drugivadmixture", "frequency", "loadingdose", "prn"],
+                 "infusionDrug" + self.ext: ["patientweight", "volumeoffluid", "drugrate", "drugamount"],
+             }
+
+             for table in self.tables:
+                 if table["fname"] in extra_exclude_feature_dict.keys():
+                     exclude_target_list = extra_exclude_feature_dict[table["fname"]]
+                     table["exclude"].extend(exclude_target_list)
+
+         if self.emb_type == "codebase":
+             feature_types_for_codebase_emb_dict = {
+                 "lab"
+                 + self.ext: {
+                     "numeric_feat": ["labresult", "labresulttext"],
+                     "categorical_feat": ["labtypeid"],
+                     "code_feat": ["labname"],
+                 },
+                 "medication"
+                 + self.ext: {
+                     "numeric_feat": ["dosage"],
+                     "categorical_feat": ["drugordercancelled", "drugivadmixture"],
+                     "code_feat": ["drugname"],
+                 },
+                 "infusionDrug"
+                 + self.ext: {
+                     "numeric_feat": [
+                         "drugrate",
+                         "infusionrate",
+                         "drugamount",
+                         "volumeoffluid",
+                         "patientweight",
+                     ],
+                     "categorical_feat": [],
+                     "code_feat": ["drugname"],
+                 },
+             }
+
+             for table in self.tables:
+                 if table["fname"] in feature_types_for_codebase_emb_dict.keys():
+                     feature_dict = feature_types_for_codebase_emb_dict[table["fname"]]
+                     table.update(feature_dict)
+
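+         # Lab item IDs for the clinical prediction tasks; the `itemid` strings must
+         # match eICU `labname` values verbatim (e.g., "platelets x 1000"). The
+         # `dialysis` entry is only used to censor creatinine labels after dialysis.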
+         if self.creatinine or self.bilirubin or self.platelets or self.wbc:
+             self.task_itemids = {
+                 "creatinine": {
+                     "fname": "lab" + self.ext,
+                     "timestamp": "labresultoffset",
+                     "timeoffsetunit": "min",
+                     "exclude": [
+                         "labtypeid",
+                         "labresulttext",
+                         "labmeasurenamesystem",
+                         "labmeasurenameinterface",
+                         "labresultrevisedoffset",
+                     ],
+                     "code": ["labname"],
+                     "value": ["labresult"],
+                     "itemid": ["creatinine"],
+                 },
+                 "bilirubin": {
+                     "fname": "lab" + self.ext,
+                     "timestamp": "labresultoffset",
+                     "timeoffsetunit": "min",
+                     "exclude": [
+                         "labtypeid",
+                         "labresulttext",
+                         "labmeasurenamesystem",
+                         "labmeasurenameinterface",
+                         "labresultrevisedoffset",
+                     ],
+                     "code": ["labname"],
+                     "value": ["labresult"],
+                     "itemid": ["total bilirubin"],
+                 },
+                 "platelets": {
+                     "fname": "lab" + self.ext,
+                     "timestamp": "labresultoffset",
+                     "timeoffsetunit": "min",
+                     "exclude": [
+                         "labtypeid",
+                         "labresulttext",
+                         "labmeasurenamesystem",
+                         "labmeasurenameinterface",
+                         "labresultrevisedoffset",
+                     ],
+                     "code": ["labname"],
+                     "value": ["labresult"],
+                     "itemid": ["platelets x 1000"],
+                 },
+                 "wbc": {
+                     "fname": "lab" + self.ext,
+                     "timestamp": "labresultoffset",
+                     "timeoffsetunit": "min",
+                     "exclude": [
+                         "labtypeid",
+                         "labresulttext",
+                         "labmeasurenamesystem",
+                         "labmeasurenameinterface",
+                         "labresultrevisedoffset",
+                     ],
+                     "code": ["labname"],
+                     "value": ["labresult"],
+                     "itemid": ["WBC x 1000"],
+                 },
+                 "dialysis": {
+                     "fname": "intakeOutput" + self.ext,
+                     "timestamp": "intakeoutputoffset",
+                     "timeoffsetunit": "min",
+                     "exclude": [
+                         "intakeoutputid",
+                         "intaketotal",
+                         "outputtotal",
+                         "nettotal",
+                         "intakeoutputentryoffset",
+                     ],
+                     "code": ["dialysistotal"],
+                     "value": [],
+                     "itemid": [],
+                 },
+             }
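+         # Collapse eICU hospital discharge locations into the shared categories
+         # (HOS_DISCHARGE_LOCATION) used across the EHR preprocessors.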
+
+         self.disch_map_dict = {
+             "Home": "Home",
+             "IN_ICU_MORTALITY": "IN_ICU_MORTALITY",
+             "Nursing Home": "Other",
+             "Other": "Other",
+             "Other External": "Other",
+             "Other Hospital": "Other",
+             "Rehabilitation": "Rehabilitation",
+             "Skilled Nursing Facility": "Skilled Nursing Facility",
+             "Death": "Death",
+         }
+
+         self._icustay_key = "patientunitstayid"
+         self._hadm_key = "patienthealthsystemstayid"
+         self._patient_key = "uniquepid"
+
+         self._determine_first_icu = "unitvisitnumber"
+
+     def build_cohorts(self, cached=False):
+         icustays = pd.read_csv(os.path.join(self.data_dir, self.icustay_fname))
+
+         icustays = self.make_compatible(icustays)
+         self.icustays = icustays
+
+         cohorts = super().build_cohorts(icustays, cached=cached)
+
+         return cohorts
+
+     def prepare_tasks(self, cohorts, spark, cached=False):
+         if cohorts is None and cached:
+             labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
+             if labeled_cohorts is not None:
+                 return labeled_cohorts
+
+         labeled_cohorts = super().prepare_tasks(cohorts, spark, cached)
+
+         if self.diagnosis:
+             logger.info("Start labeling cohorts for diagnosis prediction.")
+
+             str2cat = self.make_dx_mapping()
+             dx = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))
+             dx = dx.merge(cohorts[[self.icustay_key, self.hadm_key]], on=self.icustay_key)
+             dx["diagnosis"] = dx["diagnosisstring"].map(lambda x: str2cat.get(x, -1))
+             # drop unmapped strings (-1) and the rare class 14, then shift the
+             # categories above 14 down by one to keep the label space contiguous
+             dx = dx[(dx["diagnosis"] != -1) & (dx["diagnosis"] != 14)]
+             dx.loc[dx["diagnosis"] >= 14, "diagnosis"] -= 1
+             dx = dx.groupby(self.hadm_key)["diagnosis"].agg(lambda x: list(set(x))).to_frame()
+
+             labeled_cohorts = labeled_cohorts.merge(dx, on=self.hadm_key, how="left")
+             # hospital stays with no mapped diagnosis (NaN) get an empty label list
+             labeled_cohorts["diagnosis"] = labeled_cohorts["diagnosis"].apply(
+                 lambda x: [] if type(x) is not list else x
+             )
+
+             logger.info("Done preparing diagnosis prediction for the given cohorts")
+
+         self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")
+
+         if self.bilirubin or self.platelets or self.creatinine or self.wbc:
+             logger.info("Start labeling cohorts for clinical task prediction.")
+
+             labeled_cohorts = spark.createDataFrame(labeled_cohorts)
+
+             if self.bilirubin:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "bilirubin", spark)
+
+             if self.platelets:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "platelets", spark)
+
+             if self.creatinine:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "creatinine", spark)
+
+             if self.wbc:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "wbc", spark)
+
+             # self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled.clinical_tasks")
+
+             logger.info("Done preparing clinical task prediction for the given cohorts")
+
+         return labeled_cohorts
+
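+     # make_compatible derives the shared schema used by the other EHR modules
+     # (INTIME/OUTTIME/DISCHTIME offsets, AGE, LOS, IN_ICU_MORTALITY,
+     # HOS_DISCHARGE_LOCATION) from eICU's patient table.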
+     def make_compatible(self, icustays):
+         icustays.loc[:, "LOS"] = icustays["unitdischargeoffset"] / 60 / 24
+         icustays.dropna(subset=["age"], inplace=True)
+         icustays["AGE"] = icustays["age"].replace("> 89", 300).astype(int)
+
+         # hacks for compatibility with other EHRs
+         icustays["INTIME"] = 0
+         icustays.rename(columns={"unitdischargeoffset": "OUTTIME"}, inplace=True)
+         # DEATHTIME
+         # icustays["DEATHTIME"] = np.nan
+         # is_discharged_in_icu = icustays["unitdischargestatus"] == "Expired"
+         # icustays.loc[is_discharged_in_icu, "DEATHTIME"] = (
+         #     icustays.loc[is_discharged_in_icu, "OUTTIME"]
+         # )
+         # is_discharged_in_hos = (
+         #     (icustays["unitdischargestatus"] != "Expired")
+         #     & (icustays["hospitaldischargestatus"] == "Expired")
+         # )
+         # icustays.loc[is_discharged_in_hos, "DEATHTIME"] = (
+         #     icustays.loc[is_discharged_in_hos, "OUTTIME"]
+         # ) + 1
+
+         icustays.rename(columns={"hospitaldischargeoffset": "DISCHTIME"}, inplace=True)
+
+         icustays["IN_ICU_MORTALITY"] = icustays["unitdischargestatus"] == "Expired"
+         icustays["hospitaldischargelocation"] = icustays["hospitaldischargelocation"].map(self.disch_map_dict)
+         icustays.rename(columns={"hospitaldischargelocation": "HOS_DISCHARGE_LOCATION"}, inplace=True)
+
+         return icustays
+
+     def make_dx_mapping(self):
+         diagnosis = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))
+         ccs_dx = pd.read_csv(self.ccs_path)
+         gem = pd.read_csv(self.gem_path)
+
+         diagnosis = diagnosis[["diagnosisstring", "icd9code"]]
+
+         # one-to-one matching between string and code:
+         # STR: diagnosisstring, CODE: ICD-9/10 code, CAT: category
+         # one str -> multiple codes -> one cat
+
+         # 1. make str -> code dictionary
+         str2code = diagnosis.dropna(subset=["icd9code"])
+         str2code = str2code.groupby("diagnosisstring").first().reset_index()
+         str2code["icd9code"] = str2code["icd9code"].str.split(",")
+         str2code = str2code.explode("icd9code")
+         str2code["icd9code"] = str2code["icd9code"].str.replace(".", "", regex=False)
+         # str2code = dict(zip(notnull_dx["diagnosisstring"], notnull_dx["icd9code"]))
+         # (building a dict here would silently drop duplicated diagnosis strings -> not used)
+
+         ccs_dx["'ICD-9-CM CODE'"] = ccs_dx["'ICD-9-CM CODE'"].str[1:-1].str.strip()
+         ccs_dx["'CCS LVL 1'"] = ccs_dx["'CCS LVL 1'"].str[1:-1].astype(int) - 1
+         icd2cat = dict(zip(ccs_dx["'ICD-9-CM CODE'"], ccs_dx["'CCS LVL 1'"]))
+
+         # 2. if a code is not ICD-9, convert it to ICD-9
+         # (codes not covered by the CCS ICD-9 table are assumed to be ICD-10)
+         str2code_icd10 = str2code[~str2code["icd9code"].isin(icd2cat.keys())]
+
+         map_cms = dict(zip(gem["icd10cm"], gem["icd9cm"]))
+         map_manual = dict.fromkeys(set(str2code_icd10["icd9code"]) - set(gem["icd10cm"]), "NaN")
+
+         for code_10 in map_manual:
+             for i in range(len(code_10), 0, -1):
+                 tgt_10 = code_10[:i]
+                 if tgt_10 in gem["icd10cm"].values:  # membership over values, not the index
+                     tgt_9 = gem[gem["icd10cm"].str.contains(tgt_10)]["icd9cm"].mode().iloc[0]
+                     map_manual[code_10] = tgt_9
+                     break
+         icd102icd9 = {**map_cms, **map_manual}
+
+         # 3. convert available strings to categories
+         str2cat = {}
+         for _, row in str2code.iterrows():
+             k, v = row
+             if v in icd2cat.keys():
+                 cat = icd2cat[v]
+                 if k in str2cat.keys() and str2cat[k] != cat:
+                     logger.warning(f"{k} has multiple categories: {cat}, {str2cat[k]}")
+                 str2cat[k] = icd2cat[v]
+             elif v in icd102icd9.keys():
+                 cat = icd2cat[icd102icd9[v]]
+                 if k in str2cat.keys() and str2cat[k] != cat:
+                     logger.warning(f"{k} has multiple categories: {cat}, {str2cat[k]}")
+                 str2cat[k] = icd2cat[icd102icd9[v]]
+
+         # 4. if no category is available by mapping (~25%), use the diagnosisstring hierarchy
+
+         # make tree structure
+         tree = treelib.Tree()
+         tree.create_node("root", "root")
+         for dx, cat in str2cat.items():
+             dx = dx.split("|")
+             if not tree.contains(dx[0]):
+                 tree.create_node(-1, dx[0], parent="root")
+             for i in range(2, len(dx)):
+                 if not tree.contains("|".join(dx[:i])):
+                     tree.create_node(-1, "|".join(dx[:i]), parent="|".join(dx[: i - 1]))
+             if not tree.contains("|".join(dx)):
+                 tree.create_node(cat, "|".join(dx), parent="|".join(dx[:-1]))
+
+         # update non-leaf nodes with a majority vote over their children
+         nid_list = list(tree.expand_tree(mode=treelib.Tree.DEPTH))
+         nid_list.reverse()
+         for nid in nid_list:
+             if tree.get_node(nid).is_leaf():
+                 continue
+             elif tree.get_node(nid).tag == -1:
+                 tree.get_node(nid).tag = Counter([child.tag for child in tree.children(nid)]).most_common(1)[0][0]
+
+         # assign categories to diagnosis strings that have none by walking up the hierarchy
+         unmatched_dxs = set(diagnosis["diagnosisstring"]) - set(str2cat.keys())
+         for dx in unmatched_dxs:
+             dx = dx.split("|")
+             # do not go up to the root level (can add noise)
+             for i in range(len(dx) - 1, 1, -1):
+                 if tree.contains("|".join(dx[:i])):
+                     str2cat["|".join(dx)] = tree.get_node("|".join(dx[:i])).tag
+                     break
+
+         return str2cat
+
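+     # clinical_task labels each ICU stay with the mean lab value observed between
+     # (obs_size + gap_size) and (obs_size + pred_size) hours after ICU admission;
+     # the bin cutoffs below follow SOFA-style severity thresholds.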
+     def clinical_task(self, cohorts, task, spark):
+         fname = self.task_itemids[task]["fname"]
+         timestamp = self.task_itemids[task]["timestamp"]
+         excludes = self.task_itemids[task]["exclude"]
+         code = self.task_itemids[task]["code"][0]
+         value = self.task_itemids[task]["value"][0]
+         itemid = self.task_itemids[task]["itemid"]
+
+         table = spark.read.csv(os.path.join(self.data_dir, fname), header=True)
+         table = table.drop(*excludes)
+         table = table.filter(F.col(code).isin(itemid)).filter(F.col(value).isNotNull())
+
+         merge = cohorts.join(table, on=self.icustay_key, how="inner")
+
+         if task == "creatinine":
+             # for the creatinine task, eliminate ICU stays whose patient went through
+             # dialysis treatment before the (obs_size + pred_size) timestamp
+             patient = spark.read.csv(os.path.join(self.data_dir, self._icustay_fname), header=True)
+             patient = patient.select(*[self.patient_key, self.icustay_key, self._hadm_key])  # icuunit intime
+             # patients with multiple hospital admissions
+             multi_hosp = (
+                 patient.groupBy(self.patient_key)
+                 .agg(F.count(self._hadm_key).alias("count"))
+                 .filter(F.col("count") > 1)
+                 .select(self.patient_key)
+                 .rdd.flatMap(lambda row: row)
+                 .collect()
+             )
+
+             # intakeOutput is the only table recording dialysis treatment
+             dialysis_tables = self.task_itemids["dialysis"]["fname"]
+             dialysis_code = self.task_itemids["dialysis"]["code"][0]
+             excludes = self.task_itemids["dialysis"]["exclude"]
+
+             io = spark.read.csv(os.path.join(self.data_dir, dialysis_tables), header=True)
+             io = io.drop(*excludes)
+
+             io_dialysis = io.filter(F.col(dialysis_code) != 0)
+             io_dialysis = io_dialysis.join(patient, on=self.icustay_key, how="left")
+
+             dialysis_multihosp = (
+                 io_dialysis.filter(F.col(self.patient_key).isin(multi_hosp))
+                 .select(self.patient_key)
+                 .rdd.flatMap(lambda row: row)
+                 .collect()
+             )
+
+             io_dialysis = io_dialysis.drop(self.patient_key)
+
+             def dialysis_time(table, timecolumn):
+                 return table.withColumn("_DIALYSIS_TIME", F.col(timecolumn)).select(
+                     self.icustay_key, "_DIALYSIS_TIME"
+                 )
+
+             io_dialysis = dialysis_time(io_dialysis, self.task_itemids["dialysis"]["timestamp"])
+             io_dialysis = io_dialysis.groupBy(self.icustay_key).agg(
+                 F.min("_DIALYSIS_TIME").alias("_DIALYSIS_TIME")
+             )
+             io_dialysis = io_dialysis.select([self.icustay_key, "_DIALYSIS_TIME"])
+             merge = merge.join(io_dialysis, on=self.icustay_key, how="left")
+             merge = merge.filter(F.isnull("_DIALYSIS_TIME") | (F.col("_DIALYSIS_TIME") > F.col(timestamp)))
+             merge = merge.drop("_DIALYSIS_TIME")
+
+         # keep events within the (obs_size + gap_size) - (obs_size + pred_size) window
+         merge = merge.filter(((self.obs_size + self.gap_size) * 60) <= F.col(timestamp)).filter(
+             ((self.obs_size + self.pred_size) * 60) >= F.col(timestamp)
+         )
+
+         # average value of the events
+         value_agg = merge.groupBy(self.icustay_key).agg(
+             F.mean(value).alias("avg_value")
+         )  # TODO: mean/min/max?
+
+         # labeling
+         if task == "bilirubin":
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value < 1.2, 0)
+                 .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
+                 .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 6.0), 2)
+                 .when((value_agg.avg_value >= 6.0) & (value_agg.avg_value < 12.0), 3)
+                 .when(value_agg.avg_value >= 12.0, 4),
+             )
+         elif task == "platelets":
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value >= 150, 0)
+                 .when((value_agg.avg_value >= 100) & (value_agg.avg_value < 150), 1)
+                 .when((value_agg.avg_value >= 50) & (value_agg.avg_value < 100), 2)
+                 .when((value_agg.avg_value >= 20) & (value_agg.avg_value < 50), 3)
+                 .when(value_agg.avg_value < 20, 4),
+             )
+         elif task == "creatinine":
+             value_agg = value_agg.join(
+                 patient.select([self.patient_key, self.icustay_key]), on=self.icustay_key, how="left"
+             )
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value < 1.2, 0)
+                 .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
+                 .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 3.5), 2)
+                 .when((value_agg.avg_value >= 3.5) & (value_agg.avg_value < 5), 3)
+                 .when(value_agg.avg_value >= 5, 4),
+             )
+             value_agg = value_agg.filter(~F.col(self.patient_key).isin(dialysis_multihosp))
+             value_agg = value_agg.drop(self.patient_key)
+         elif task == "wbc":
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value < 4, 0)
+                 .when((value_agg.avg_value >= 4) & (value_agg.avg_value <= 12), 1)
+                 .when((value_agg.avg_value > 12), 2),
+             )
+
+         cohorts = cohorts.join(value_agg.select(self.icustay_key, task), on=self.icustay_key, how="left")
+
+         return cohorts
+
+     def infer_data_extension(self) -> str:
+         # eICU-CRD v2.0 ships 31 tables; expect all of them in a single format
+         if len(glob.glob(os.path.join(self.data_dir, "*.csv.gz"))) == 31:
+             ext = ".csv.gz"
+         elif len(glob.glob(os.path.join(self.data_dir, "*.csv"))) == 31:
+             ext = ".csv"
+         else:
+             raise AssertionError(
+                 "Provided data directory is not correct. Please check if --data is correct. "
+                 "--data: {}".format(self.data_dir)
+             )
+
+         logger.info("Data extension is set to '{}'".format(ext))
+
+         return ext
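
Note: a minimal usage sketch of this module. The `get_ehr` accessor and the `cfg` fields are illustrative assumptions based on the `@register_ehr` decorator and the attributes referenced above, not an API verified against `main.py`:

    from ehrs import get_ehr  # hypothetical registry accessor populated by @register_ehr

    ehr_cls = get_ehr("eicu")                 # class registered by the decorator above
    ehr = ehr_cls(cfg)                        # resolves data paths, table configs, task item IDs
    cohorts = ehr.build_cohorts(cached=True)  # patient table -> normalized cohort frame
    labeled = ehr.prepare_tasks(cohorts, spark, cached=True)  # adds diagnosis / lab task labels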