genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py
@@ -0,0 +1,619 @@
+ import glob
+ import logging
+ import os
+
+ import pandas as pd
+ import pyspark.sql.functions as F
+ from ehrs import EHR, register_ehr
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_ehr("mimiciv")
+ class MIMICIV(EHR):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+         self.ehr_name = "mimiciv"
+
+         if self.data_dir is None:
+             self.data_dir = os.path.join(self.cache_dir, self.ehr_name)
+
+             if not os.path.exists(self.data_dir):
+                 logger.info(
+                     "Data not found, so trying to download it from the internet. "
+                     "Note that this is a restricted-access resource. "
+                     "Please log in to physionet.org with a credentialed user."
+                 )
+                 self.download_ehr_from_url(url="https://physionet.org/files/mimiciv/2.0/", dest=self.data_dir)
+
+         logger.info("Data directory is set to {}".format(self.data_dir))
+
+         if self.ccs_path is None:
+             self.ccs_path = os.path.join(self.cache_dir, "ccs_multi_dx_tool_2015.csv")
+
+             if not os.path.exists(self.ccs_path):
+                 logger.info("`ccs_multi_dx_tool_2015.csv` not found, so trying to download it from the internet.")
+                 self.download_ccs_from_url(self.cache_dir)
+
+         if self.gem_path is None:
+             self.gem_path = os.path.join(self.cache_dir, "icd10cmtoicd9gem.csv")
+
+             if not os.path.exists(self.gem_path):
+                 logger.info("`icd10cmtoicd9gem.csv` not found, so trying to download it from the internet.")
+                 self.download_icdgem_from_url(self.cache_dir)
+
+         if self.ext is None:
+             self.ext = self.infer_data_extension()
+
+         self._icustay_fname = "icu/icustays" + self.ext
+         self._patient_fname = "hosp/patients" + self.ext
+         self._admission_fname = "hosp/admissions" + self.ext
+         self._diagnosis_fname = "hosp/diagnoses_icd" + self.ext
+
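+         # Each entry of `self.tables` is a spec consumed by the base EHR
+         # pipeline: "fname" names the event table, "timestamp" its event-time
+         # column ("timeoffsetunit" = "abs" means absolute datetimes),
+         # "exclude" lists columns dropped from the features, and
+         # "code"/"desc"/"desc_key" map coded columns to human-readable labels
+         # from the referenced dictionary table.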
+         self.tables = [
+             {
+                 "fname": "hosp/labevents" + self.ext,
+                 "timestamp": "charttime",
+                 "timeoffsetunit": "abs",
+                 "exclude": [
+                     "labevent_id",
+                     "storetime",
+                     "subject_id",
+                     "specimen_id",
+                     "order_provider_id",  # added in MIMIC-IV v2.2
+                 ],
+                 "code": ["itemid"],
+                 "desc": ["hosp/d_labitems" + self.ext],
+                 "desc_key": ["label"],
+             },
+             {
+                 "fname": "hosp/prescriptions" + self.ext,
+                 "timestamp": "starttime",
+                 "timeoffsetunit": "abs",
+                 "exclude": [
+                     "gsn",
+                     "ndc",
+                     "subject_id",
+                     "pharmacy_id",
+                     "poe_id",
+                     "poe_seq",
+                     "formulary_drug_cd",
+                     "stoptime",
+                     "order_provider_id",  # added in MIMIC-IV v2.2
+                 ],
+             },
+             {
+                 "fname": "icu/inputevents" + self.ext,
+                 "timestamp": "starttime",
+                 "timeoffsetunit": "abs",
+                 "exclude": [
+                     "endtime",
+                     "storetime",
+                     "orderid",
+                     "linkorderid",
+                     "subject_id",
+                     "continueinnextdept",
+                     "statusdescription",
+                 ],
+                 "code": ["itemid"],
+                 "desc": ["icu/d_items" + self.ext],
+                 "desc_key": ["label"],
+             },
+         ]
+
+         if self.feature == "select":
+             extra_exclude_feature_dict = {
+                 "hosp/labevents"
+                 + self.ext: ["value", "ref_range_lower", "ref_range_upper", "flag", "priority", "comments"],
+                 "hosp/prescriptions"
+                 + self.ext: [
+                     "drug_type",
+                     "form_rx",
+                     "form_val_disp",
+                     "form_unit_disp",
+                     "doses_per_24_hrs",
+                     "route",
+                 ],
+                 "icu/inputevents"
+                 + self.ext: [
+                     "amount",
+                     "amountuom",
+                     "ordercategoryname",
+                     "secondaryordercategoryname",
+                     "ordercomponenttypedescription",
+                     "ordercategorydescription",
+                     "patientweight",
+                     "totalamount",
+                     "totalamountuom",
+                     "isopenbag",
+                     "originalamount",
+                     "originalrate",
+                 ],
+             }
+
+             for table in self.tables:
+                 if table["fname"] in extra_exclude_feature_dict.keys():
+                     exclude_target_list = extra_exclude_feature_dict[table["fname"]]
+                     table["exclude"].extend(exclude_target_list)
+
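+         # With emb_type == "codebase", events are embedded from structured
+         # fields rather than their textual descriptions (GenHPF's code-based
+         # setting), so each table also declares which columns are numeric,
+         # categorical, or medical-code features.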
+         if self.emb_type == "codebase":
+             feature_types_for_codebase_emb_dict = {
+                 "hosp/labevents"
+                 + self.ext: {
+                     "numeric_feat": ["value", "valuenum", "ref_range_upper", "ref_range_lower"],
+                     "categorical_feat": [],
+                     "code_feat": ["itemid"],
+                 },
+                 "hosp/prescriptions"
+                 + self.ext: {
+                     "numeric_feat": ["dose_val_rx", "form_val_disp"],
+                     "categorical_feat": ["doses_per_24_hrs"],
+                     "code_feat": ["drug"],
+                 },
+                 "icu/inputevents"
+                 + self.ext: {
+                     "numeric_feat": [
+                         "amount",
+                         "rate",
+                         "patientweight",
+                         "totalamount",
+                         "originalamount",
+                         "originalrate",
+                     ],
+                     "categorical_feat": ["isopenbag", "continueinnextdept", "cancelreason", "valuecounts"],
+                     "code_feat": ["itemid"],
+                 },
+             }
+
+             for table in self.tables:
+                 if table["fname"] in feature_types_for_codebase_emb_dict.keys():
+                     feature_dict = feature_types_for_codebase_emb_dict[table["fname"]]
+                     table.update(feature_dict)
+
+         if self.creatinine or self.bilirubin or self.platelets or self.wbc:
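+             # Per-task lab extraction specs. The lab itemids refer to
+             # hosp/d_labitems: 50912 (creatinine), 50885 (total bilirubin),
+             # 51265 (platelet count), and 51300/51301/51755 (white blood cell
+             # counts). The "dialysis" entry lists chartevents/inputevents/
+             # procedureevents itemids used to detect renal replacement therapy.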
+             self.task_itemids = {
+                 "creatinine": {
+                     "fname": "hosp/labevents" + self.ext,
+                     "timestamp": "charttime",
+                     "timeoffsetunit": "abs",
+                     "exclude": [
+                         "labevent_id",
+                         "subject_id",
+                         "specimen_id",
+                         "storetime",
+                         "value",
+                         "valueuom",
+                         "ref_range_lower",
+                         "ref_range_upper",
+                         "flag",
+                         "priority",
+                         "comments",
+                     ],
+                     "code": ["itemid"],
+                     "value": ["valuenum"],
+                     "itemid": [50912],
+                 },
+                 "bilirubin": {
+                     "fname": "hosp/labevents" + self.ext,
+                     "timestamp": "charttime",
+                     "timeoffsetunit": "abs",
+                     "exclude": [
+                         "labevent_id",
+                         "subject_id",
+                         "specimen_id",
+                         "storetime",
+                         "value",
+                         "valueuom",
+                         "ref_range_lower",
+                         "ref_range_upper",
+                         "flag",
+                         "priority",
+                         "comments",
+                     ],
+                     "code": ["itemid"],
+                     "value": ["valuenum"],
+                     "itemid": [50885],
+                 },
+                 "platelets": {
+                     "fname": "hosp/labevents" + self.ext,
+                     "timestamp": "charttime",
+                     "timeoffsetunit": "abs",
+                     "exclude": [
+                         "labevent_id",
+                         "subject_id",
+                         "specimen_id",
+                         "storetime",
+                         "value",
+                         "valueuom",
+                         "ref_range_lower",
+                         "ref_range_upper",
+                         "flag",
+                         "priority",
+                         "comments",
+                     ],
+                     "code": ["itemid"],
+                     "value": ["valuenum"],
+                     "itemid": [51265],
+                 },
+                 "wbc": {
+                     "fname": "hosp/labevents" + self.ext,
+                     "timestamp": "charttime",
+                     "timeoffsetunit": "abs",
+                     "exclude": [
+                         "labevent_id",
+                         "subject_id",
+                         "specimen_id",
+                         "storetime",
+                         "value",
+                         "valueuom",
+                         "ref_range_lower",
+                         "ref_range_upper",
+                         "flag",
+                         "priority",
+                         "comments",
+                     ],
+                     "code": ["itemid"],
+                     "value": ["valuenum"],
+                     "itemid": [51300, 51301, 51755],
+                 },
+                 "dialysis": {
+                     "tables": {
+                         "chartevents": {
+                             "fname": "icu/chartevents" + self.ext,
+                             "timestamp": "charttime",
+                             "timeoffsetunit": "abs",
+                             "include": ["subject_id", "itemid", "value", "charttime"],
+                             "itemid": {
+                                 "ce": [
+                                     226499,
+                                     224154,
+                                     225183,
+                                     227438,
+                                     224191,
+                                     225806,
+                                     225807,
+                                     228004,
+                                     228005,
+                                     228006,
+                                     224144,
+                                     224145,
+                                     224153,
+                                     226457,
+                                 ]
+                             },
+                         },
+                         "inputevents": {
+                             "fname": "icu/inputevents" + self.ext,
+                             "timestamp": "starttime",
+                             "timeoffsetunit": "abs",
+                             "include": ["subject_id", "itemid", "amount", "starttime"],
+                             "itemid": {"ie": [227536, 227525]},
+                         },
+                         "procedureevents": {
+                             "fname": "icu/procedureevents" + self.ext,
+                             "timestamp": "starttime",
+                             "timeoffsetunit": "abs",
+                             "include": ["subject_id", "itemid", "value", "starttime"],
+                             "itemid": {"pe": [225441, 225802, 225803, 225805, 225809, 225955]},
+                         },
+                     }
+                 },
+             }
+
+         self.disch_map_dict = {
+             "ACUTE HOSPITAL": "Other",
+             "AGAINST ADVICE": "Other",
+             "ASSISTED LIVING": "Other",
+             "CHRONIC/LONG TERM ACUTE CARE": "Other",
+             "HEALTHCARE FACILITY": "Other",
+             "HOME": "Home",
+             "HOME HEALTH CARE": "Home",
+             "HOSPICE": "Other",
+             "IN_ICU_MORTALITY": "IN_ICU_MORTALITY",
+             "OTHER FACILITY": "Other",
+             "PSYCH FACILITY": "Other",
+             "REHAB": "Rehabilitation",
+             "SKILLED NURSING FACILITY": "Skilled Nursing Facility",
+             "Death": "Death",
+         }
+
+         self._icustay_key = "stay_id"
+         self._hadm_key = "hadm_id"
+         self._patient_key = "subject_id"
+
+         self._determine_first_icu = "INTIME"
+
+     def build_cohorts(self, cached=False):
+         icustays = pd.read_csv(os.path.join(self.data_dir, self.icustay_fname))
+
+         icustays = self.make_compatible(icustays)
+         self.icustays = icustays
+
+         cohorts = super().build_cohorts(icustays, cached=cached)
+
+         return cohorts
+
+     def prepare_tasks(self, cohorts, spark, cached=False):
+         if cached:
+             labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
+             if labeled_cohorts is not None:
+                 return labeled_cohorts
+
+         labeled_cohorts = super().prepare_tasks(cohorts, spark, cached)
+
+         if self.diagnosis:
+             logger.info("Start labeling cohorts for diagnosis prediction.")
+
+             # define diagnosis prediction task
+             diagnoses = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))
+
+             diagnoses = self.icd10toicd9(diagnoses)
+
+             ccs_dx = pd.read_csv(self.ccs_path)
+             ccs_dx["'ICD-9-CM CODE'"] = ccs_dx["'ICD-9-CM CODE'"].str[1:-1].str.strip()
+             ccs_dx["'CCS LVL 1'"] = ccs_dx["'CCS LVL 1'"].str[1:-1]
+             lvl1 = {x: int(y) - 1 for _, (x, y) in ccs_dx[["'ICD-9-CM CODE'", "'CCS LVL 1'"]].iterrows()}
+
+             diagnoses["diagnosis"] = diagnoses["icd_code_converted"].map(lvl1)
+
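+             # `lvl1` maps each ICD-9 code to one of the 18 CCS level-1
+             # categories, 0-indexed. Category 14 (conditions originating in
+             # the perinatal period) is dropped and the remaining labels are
+             # re-indexed, leaving 17 diagnosis classes.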
+             diagnoses = diagnoses[(diagnoses["diagnosis"].notnull()) & (diagnoses["diagnosis"] != 14)]
+             diagnoses.loc[diagnoses["diagnosis"] >= 14, "diagnosis"] -= 1
+             diagnoses = diagnoses.groupby(self.hadm_key)["diagnosis"].agg(lambda x: list(set(x))).to_frame()
+
+             labeled_cohorts = labeled_cohorts.merge(diagnoses, on=self.hadm_key, how="inner")
+
+             logger.info("Done preparing diagnosis prediction for the given cohorts")
+
+             self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")
+
+         if self.bilirubin or self.platelets or self.creatinine or self.wbc:
+             logger.info("Start labeling cohorts for clinical task prediction.")
+
+             labeled_cohorts = spark.createDataFrame(labeled_cohorts)
+
+             if self.bilirubin:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "bilirubin", spark)
+
+             if self.platelets:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "platelets", spark)
+
+             if self.creatinine:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "creatinine", spark)
+
+             if self.wbc:
+                 labeled_cohorts = self.clinical_task(labeled_cohorts, "wbc", spark)
+
+             # labeled_cohorts = labeled_cohorts.toPandas()
+
+             # self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled.clinical_tasks")
+
+             logger.info("Done preparing clinical task prediction for the given cohorts")
+
+         return labeled_cohorts
+
+     def make_compatible(self, icustays):
+         patients = pd.read_csv(os.path.join(self.data_dir, self.patient_fname))
+         admissions = pd.read_csv(os.path.join(self.data_dir, self.admission_fname))
+
+         # prepare icustays according to the appropriate format
+         icustays = icustays.rename(
+             columns={
+                 "los": "LOS",
+                 "intime": "INTIME",
+                 "outtime": "OUTTIME",
+             }
+         )
+         admissions = admissions.rename(
+             columns={
+                 "dischtime": "DISCHTIME",
+             }
+         )
+
+         icustays = icustays[icustays["first_careunit"] == icustays["last_careunit"]]
+         icustays["INTIME"] = pd.to_datetime(icustays["INTIME"], infer_datetime_format=True, utc=True)
+         icustays["OUTTIME"] = pd.to_datetime(icustays["OUTTIME"], infer_datetime_format=True, utc=True)
+
+         icustays = icustays.merge(patients, on="subject_id", how="left")
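+         # MIMIC-IV de-identifies dates: `anchor_age` is the patient's age in
+         # `anchor_year`, so age at ICU admission is
+         # intime year - anchor_year + anchor_age.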
+         icustays["AGE"] = icustays["INTIME"].dt.year - icustays["anchor_year"] + icustays["anchor_age"]
+
+         icustays = icustays.merge(
+             admissions[[self.hadm_key, "discharge_location", "deathtime", "DISCHTIME"]],
+             how="left",
+             on=self.hadm_key,
+         )
+
+         icustays["discharge_location"].replace("DIED", "Death", inplace=True)
+         icustays["DISCHTIME"] = pd.to_datetime(icustays["DISCHTIME"], infer_datetime_format=True, utc=True)
+
+         icustays["IN_ICU_MORTALITY"] = (
+             (icustays["INTIME"] < icustays["DISCHTIME"])
+             & (icustays["DISCHTIME"] <= icustays["OUTTIME"])
+             & (icustays["discharge_location"] == "Death")
+         )
+         icustays["discharge_location"] = icustays["discharge_location"].map(self.disch_map_dict)
+         icustays.rename(columns={"discharge_location": "HOS_DISCHARGE_LOCATION"}, inplace=True)
+
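+         # Convert absolute times to minutes elapsed since ICU admission.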
+         icustays["DISCHTIME"] = (icustays["DISCHTIME"] - icustays["INTIME"]).dt.total_seconds() // 60
+         icustays["OUTTIME"] = (icustays["OUTTIME"] - icustays["INTIME"]).dt.total_seconds() // 60
+
+         return icustays
+
+     def icd10toicd9(self, dx):
+         gem = pd.read_csv(self.gem_path)
+         dx_icd_10 = dx[dx["icd_version"] == 10]["icd_code"]
+
+         unique_elem_no_map = set(dx_icd_10) - set(gem["icd10cm"])
+
+         map_cms = dict(zip(gem["icd10cm"], gem["icd9cm"]))
+         map_manual = dict.fromkeys(unique_elem_no_map, "NaN")
+
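+         # Codes missing from the GEM are mapped by progressively truncating
+         # the ICD-10 code and taking the most frequent ICD-9 mapping among
+         # GEM entries containing the truncated prefix.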
+         for code_10 in map_manual:
+             for i in range(len(code_10), 0, -1):
+                 tgt_10 = code_10[:i]
+                 if tgt_10 in gem["icd10cm"].values:
+                     tgt_9 = gem[gem["icd10cm"].str.contains(tgt_10)]["icd9cm"].mode().iloc[0]
+                     map_manual[code_10] = tgt_9
+                     break
+
+         def icd_convert(icd_version, icd_code):
+             if icd_version == 9:
+                 return icd_code
+
+             elif icd_code in map_cms:
+                 return map_cms[icd_code]
+
+             elif icd_code in map_manual:
+                 return map_manual[icd_code]
+             else:
+                 logger.warning("WRONG CODE: " + icd_code)
+
+         dx["icd_code_converted"] = dx.apply(lambda x: icd_convert(x["icd_version"], x["icd_code"]), axis=1)
+         return dx
+
+     def clinical_task(self, cohorts, task, spark):
+         fname = self.task_itemids[task]["fname"]
+         timestamp = self.task_itemids[task]["timestamp"]
+         excludes = self.task_itemids[task]["exclude"]
+         code = self.task_itemids[task]["code"][0]
+         value = self.task_itemids[task]["value"][0]
+         itemid = self.task_itemids[task]["itemid"]
+
+         table = spark.read.csv(os.path.join(self.data_dir, fname), header=True)
+         table = table.drop(*excludes)
+         table = table.filter(F.col(code).isin(itemid)).filter(F.col(value).isNotNull())
+
+         merge = cohorts.join(table, on=self.hadm_key, how="inner")
+         merge = merge.withColumn(timestamp, F.to_timestamp(timestamp))
+
+         # Filter dialysis here, while timestamps are still absolute, and
+         # aggregate by patient_key. For the creatinine task, drop ICU stays
+         # where the patient received dialysis before the
+         # (obs_size + pred_size / outtime) timestamp.
+         # see https://github.com/MIT-LCP/mimic-code/blob/main/mimic-iv/concepts/treatment/rrt.sql
+         if task == "creatinine":
+             dialysis_tables = self.task_itemids["dialysis"]["tables"]
+
+             chartevents = spark.read.csv(
+                 os.path.join(self.data_dir, "icu/chartevents" + self.ext), header=True
+             )
+             inputevents = spark.read.csv(
+                 os.path.join(self.data_dir, "icu/inputevents" + self.ext), header=True
+             )
+             procedureevents = spark.read.csv(
+                 os.path.join(self.data_dir, "icu/procedureevents" + self.ext), header=True
+             )
+
+             chartevents = chartevents.select(*dialysis_tables["chartevents"]["include"])
+             inputevents = inputevents.select(*dialysis_tables["inputevents"]["include"])
+             procedureevents = procedureevents.select(*dialysis_tables["procedureevents"]["include"])
+
+             # Filter dialysis-related tables on the dialysis condition  # TODO: check dialysis condition
+             ce = chartevents.filter(
+                 (
+                     ((F.col("itemid") == 225965) & (F.col("value") == "In use"))
+                     | (F.col("itemid").isin(dialysis_tables["chartevents"]["itemid"]["ce"]))
+                     & F.col("value").isNotNull()
+                 )
+             )
+             ie = inputevents.filter(
+                 F.col("itemid").isin(dialysis_tables["inputevents"]["itemid"]["ie"])
+             ).filter(F.col("amount") > 0)
+             pe = procedureevents.filter(
+                 F.col("itemid").isin(dialysis_tables["procedureevents"]["itemid"]["pe"])
+             ).filter(F.col("value").isNotNull())
+
+             # Extract dialysis times
+             def dialysis_time(table, timecolumn):
+                 return table.withColumn("_DIALYSIS_TIME", F.to_timestamp(timecolumn)).select(
+                     self.patient_key, "_DIALYSIS_TIME"
+                 )
+
+             ce, ie, pe = (
+                 dialysis_time(ce, "charttime"),
+                 dialysis_time(ie, "starttime"),
+                 dialysis_time(pe, "starttime"),
+             )
+             dialysis = ce.union(ie).union(pe)
+             dialysis = dialysis.groupby(self.patient_key).agg(F.min("_DIALYSIS_TIME").alias("_DIALYSIS_TIME"))
+             merge = merge.join(dialysis, on=self.patient_key, how="left")
+             # Keep only events with no dialysis, or events before the first dialysis
+             merge = merge.filter(F.isnull("_DIALYSIS_TIME") | (F.col("_DIALYSIS_TIME") > F.col(timestamp)))
+             merge = merge.drop("_DIALYSIS_TIME")
+
+         merge = merge.withColumn(
+             timestamp, F.round((F.col(timestamp).cast("long") - F.col("INTIME").cast("long")) / 60)
+         )
+
+         # Keep events inside the prediction window, i.e. between
+         # (obs_size + gap_size) and (obs_size + pred_size) hours after INTIME
+         # (the timestamp is in minutes at this point, hence * 60)
+         merge = merge.filter(((self.obs_size + self.gap_size) * 60) <= F.col(timestamp)).filter(
+             ((self.obs_size + self.pred_size) * 60) >= F.col(timestamp)
+         )
+
+         # Average value of events
+         value_agg = merge.groupBy(self.icustay_key).agg(
+             F.mean(value).alias("avg_value")
+         )  # TODO: mean/min/max?
+
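+         # The bins below follow the SOFA severity cutoffs for the liver
+         # (bilirubin), coagulation (platelets), and renal (creatinine)
+         # components; WBC is bucketed into low / normal / high.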
+         # Labeling
+         if task == "bilirubin":
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value < 1.2, 0)
+                 .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
+                 .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 6.0), 2)
+                 .when((value_agg.avg_value >= 6.0) & (value_agg.avg_value < 12.0), 3)
+                 .when(value_agg.avg_value >= 12.0, 4),
+             )
+         elif task == "platelets":
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value >= 150, 0)
+                 .when((value_agg.avg_value >= 100) & (value_agg.avg_value < 150), 1)
+                 .when((value_agg.avg_value >= 50) & (value_agg.avg_value < 100), 2)
+                 .when((value_agg.avg_value >= 20) & (value_agg.avg_value < 50), 3)
+                 .when(value_agg.avg_value < 20, 4),
+             )
+
+         elif task == "creatinine":
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value < 1.2, 0)
+                 .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
+                 .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 3.5), 2)
+                 .when((value_agg.avg_value >= 3.5) & (value_agg.avg_value < 5), 3)
+                 .when(value_agg.avg_value >= 5, 4),
+             )
+
+         elif task == "wbc":
+             # NOTE: unit is K/uL (10^3 cells per microliter)
+             value_agg = value_agg.withColumn(
+                 task,
+                 F.when(value_agg.avg_value < 4, 0)
+                 .when((value_agg.avg_value >= 4) & (value_agg.avg_value <= 12), 1)
+                 .when((value_agg.avg_value > 12), 2),
+             )
+
+         cohorts = cohorts.join(value_agg.select(self.icustay_key, task), on=self.icustay_key, how="left")
+
+         return cohorts
+
+     def infer_data_extension(self) -> str:
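+         # A complete MIMIC-IV 2.0 download has 22 tables under hosp/ and 9
+         # under icu/; whichever extension matches either count is assumed
+         # for every table filename.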
+         if (
+             len(glob.glob(os.path.join(self.data_dir, "hosp", "*.csv.gz"))) == 22
+             or len(glob.glob(os.path.join(self.data_dir, "icu", "*.csv.gz"))) == 9
+         ):
+             ext = ".csv.gz"
+         elif (
+             len(glob.glob(os.path.join(self.data_dir, "hosp", "*.csv"))) == 22
+             or len(glob.glob(os.path.join(self.data_dir, "icu", "*.csv"))) == 9
+         ):
+             ext = ".csv"
+         else:
+             raise AssertionError(
+                 "Provided data directory does not look like a MIMIC-IV root. Please check if --data is correct. "
+                 "--data: {}".format(self.data_dir)
+             )
+
+         logger.info("Data extension is set to '{}'".format(ext))
+
+         return ext
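
The ICD-10 to ICD-9 fallback in icd10toicd9 is the least obvious piece of this file (note that pandas `in` on a Series tests the index, not the values, hence `.values` above). The following is a minimal, self-contained sketch of that truncation logic for a single code; it is not part of the package diff, and the three-row `gem` frame is toy data standing in for `icd10cmtoicd9gem.csv`.

    import pandas as pd

    # Toy stand-in for the GEM file; only these two columns are used.
    gem = pd.DataFrame(
        {
            "icd10cm": ["E119", "E116", "E1165"],
            "icd9cm": ["25000", "25090", "25000"],
        }
    )

    def fallback_icd9(code_10: str) -> str:
        """Single-code version of the truncation loop in MIMICIV.icd10toicd9."""
        for i in range(len(code_10), 0, -1):
            tgt_10 = code_10[:i]
            if tgt_10 in gem["icd10cm"].values:  # exact match on the truncated code
                # modal ICD-9 mapping among GEM codes containing the prefix
                return gem[gem["icd10cm"].str.contains(tgt_10)]["icd9cm"].mode().iloc[0]
        return "NaN"

    print(fallback_icd9("E11621"))  # "E11621" and "E1162" miss; "E116" hits -> "25000"

For orientation, the overall flow implied by the methods above is: construct `MIMICIV(cfg)`, call `build_cohorts()` (reads icustays and normalizes it via `make_compatible`), then `prepare_tasks(cohorts, spark)` to attach the diagnosis and lab-based task labels.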