genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
@@ -0,0 +1,839 @@
1
+ import glob
2
+ import logging
3
+ import os
4
+ from datetime import datetime
5
+
6
+ import pandas as pd
7
+ import pyspark.sql.functions as F
8
+ from ehrs import EHR, register_ehr
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ @register_ehr("mimiciii")
14
+ class MIMICIII(EHR):
15
+ def __init__(self, cfg):
16
+ super().__init__(cfg)
17
+
18
+ self.ehr_name = "mimiciii"
19
+
20
+ if self.data_dir is None:
21
+ self.data_dir = os.path.join(self.cache_dir, self.ehr_name)
22
+
23
+ if not os.path.exists(self.data_dir):
24
+ logger.info(
25
+ "Data is not found so try to download from the internet. "
26
+ "Note that this is a restricted-access resource. "
27
+ "Please log in to physionet.org with a credentialed user."
28
+ )
29
+ self.download_ehr_from_url(
30
+ url="https://physionet.org/files/mimiciii/1.4/", dest=self.data_dir
31
+ )
32
+
33
+ logger.info("Data directory is set to {}".format(self.data_dir))
34
+
35
+ if self.ccs_path is None:
36
+ self.ccs_path = os.path.join(self.cache_dir, "ccs_multi_dx_tool_2015.csv")
37
+
38
+ if not os.path.exists(self.ccs_path):
39
+ logger.info("`ccs_multi_dx_tool_2015.csv` is not found so try to download from the internet.")
40
+ self.download_ccs_from_url(self.cache_dir)
41
+
42
+ if self.ext is None:
43
+ self.ext = self.infer_data_extension()
44
+
45
+ # constants
46
+ self._icustay_fname = "ICUSTAYS" + self.ext
47
+ self._patient_fname = "PATIENTS" + self.ext
48
+ self._admission_fname = "ADMISSIONS" + self.ext
49
+ self._diagnosis_fname = "DIAGNOSES_ICD" + self.ext
50
+
51
+ # XXX more features? user choice?
52
+ self.tables = [
53
+ {
54
+ "fname": "LABEVENTS" + self.ext,
55
+ "timestamp": "CHARTTIME",
56
+ "timeoffsetunit": "abs",
57
+ "exclude": ["ROW_ID", "SUBJECT_ID"],
58
+ "code": ["ITEMID"],
59
+ "desc": ["D_LABITEMS" + self.ext],
60
+ "desc_key": ["LABEL"],
61
+ },
62
+ {
63
+ "fname": "PRESCRIPTIONS" + self.ext,
64
+ "timestamp": "STARTDATE",
65
+ "timeoffsetunit": "abs",
66
+ "exclude": ["ENDDATE", "GSN", "NDC", "ROW_ID", "SUBJECT_ID"],
67
+ },
68
+ {
69
+ "fname": "INPUTEVENTS_MV" + self.ext,
70
+ "timestamp": "STARTTIME",
71
+ "timeoffsetunit": "abs",
72
+ "exclude": [
73
+ "ENDTIME",
74
+ "STORETIME",
75
+ "CGID",
76
+ "ORDERID",
77
+ "LINKORDERID",
78
+ "ROW_ID",
79
+ "SUBJECT_ID",
80
+ "CONTINUEINNEXTDEPT",
81
+ "CANCELREASON",
82
+ "STATUSDESCRIPTION",
83
+ "COMMENTS_CANCELEDBY",
84
+ "COMMENTS_DATE",
85
+ ],
86
+ "code": ["ITEMID"],
87
+ "desc": ["D_ITEMS" + self.ext],
88
+ "desc_key": ["LABEL"],
89
+ },
90
+ {
91
+ "fname": "INPUTEVENTS_CV" + self.ext,
92
+ "timestamp": "CHARTTIME",
93
+ "timeoffsetunit": "abs",
94
+ "exclude": [
95
+ "STORETIME",
96
+ "CGID",
97
+ "ORDERID",
98
+ "LINKORDERID",
99
+ "ROW_ID",
100
+ "STOPPED",
101
+ "SUBJECT_ID",
102
+ ],
103
+ "code": ["ITEMID"],
104
+ "desc": ["D_ITEMS" + self.ext],
105
+ "desc_key": ["LABEL"],
106
+ },
107
+ ]
108
+
109
+ if self.feature == "select":
110
+ extra_exclude_feature_dict = {
111
+ "LABEVENTS" + self.ext: ["VALUE", "FLAG"],
112
+ "PRESCRIPTIONS"
113
+ + self.ext: [
114
+ "DRUG_TYPE",
115
+ "DRUG_NAME_POE",
116
+ "DRUG_NAME_GENERIC",
117
+ "FORMULARY_DRUG_CD",
118
+ "FORM_VAL_DISP",
119
+ "FORM_UNIT_DISP",
120
+ ],
121
+ "INPUTEVENTS_MV"
122
+ + self.ext: [
123
+ "AMOUNT",
124
+ "AMOUNTUOM",
125
+ "ORDERCATEGORYNAME",
126
+ "SECONDARYORDERCATEGORYNAME",
127
+ "ORDERCOMPONENTTYPEDESCRIPTION",
128
+ "ORDERCATEGORYDESCRIPTION",
129
+ "PATIENTWEIGHT",
130
+ "TOTALAMOUNT",
131
+ "TOTALAMOUNTUOM",
132
+ "ISOPENBAG",
133
+ "COMMENTS_EDITEDBY",
134
+ "ORIGINALAMOUNT",
135
+ "ORIGINALRATE",
136
+ ],
137
+ "INPUTEVENTS_CV"
138
+ + self.ext: [
139
+ "AMOUNT",
140
+ "AMOUNTUOM",
141
+ "NEWBOTTLE",
142
+ "ORIGINALAMOUNT",
143
+ "ORIGINALRATE",
144
+ "ORIGIANLAMOUNTUOM",
145
+ "ORIGINALRATE",
146
+ "ORIGINALRATEUOM",
147
+ "ORIGINALSITE",
148
+ ],
149
+ }
150
+
151
+ for table in self.tables:
152
+ if table["fname"] in extra_exclude_feature_dict.keys():
153
+ exclude_target_list = extra_exclude_feature_dict[table["fname"]]
154
+ table["exclude"].extend(exclude_target_list)
155
+
156
+ if self.emb_type == "codebase":
157
+ feature_types_for_codebase_emb_dict = {
158
+ "LABEVENTS"
159
+ + self.ext: {
160
+ "numeric_feat": ["VALUE", "VALUENUM"],
161
+ "categorical_feat": [],
162
+ "code_feat": ["ITEMID"],
163
+ },
164
+ "PRESCRIPTIONS"
165
+ + self.ext: {
166
+ "numeric_feat": ["DOSE_VAL_RX", "FORM_VAL_DISP"],
167
+ "categorical_feat": [],
168
+ "code_feat": ["DRUG"],
169
+ },
170
+ "INPUTEVENTS_MV"
171
+ + self.ext: {
172
+ "numeric_feat": [
173
+ "AMOUNT",
174
+ "RATE",
175
+ "PATIENTWEIGHT",
176
+ "TOTALAMOUNT",
177
+ "ORIGINALAMOUNT",
178
+ "ORIGINALRATE",
179
+ ],
180
+ "categorical_feat": ["ISOPENBAG", "CONTINUEINNEXTDEPT", "CANCELREASON", "VALUECOUNTS"],
181
+ "code_feat": ["ITEMID"],
182
+ },
183
+ "INPUTEVENTS_CV"
184
+ + self.ext: {
185
+ "numeric_feat": [
186
+ "AMOUNT",
187
+ "RATE",
188
+ "PATIENTWEIGHT",
189
+ "TOTALAMOUNT",
190
+ "ORIGINALAMOUNT",
191
+ "ORIGINALRATE",
192
+ ],
193
+ "categorical_feat": ["ISOPENBAG", "CONTINUEINNEXTDEPT", "CANCELREASON", "VALUECOUNTS"],
194
+ "code_feat": ["ITEMID"],
195
+ },
196
+ }
197
+
198
+ for table in self.tables:
199
+ if table["fname"] in feature_types_for_codebase_emb_dict.keys():
200
+ feature_dict = feature_types_for_codebase_emb_dict[table["fname"]]
201
+ table.update(feature_dict)
202
+
203
+ if self.creatinine or self.bilirubin or self.platelets or self.wbc:
204
+ self.task_itemids = {
205
+ "creatinine": {
206
+ "fname": "LABEVENTS" + self.ext,
207
+ "timestamp": "CHARTTIME",
208
+ "timeoffsetunit": "abs",
209
+ "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "VALUEUOM", "FLAG"],
210
+ "code": ["ITEMID"],
211
+ "value": ["VALUENUM"],
212
+ "itemid": [50912],
213
+ },
214
+ "bilirubin": {
215
+ "fname": "LABEVENTS" + self.ext,
216
+ "timestamp": "CHARTTIME",
217
+ "timeoffsetunit": "abs",
218
+ "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "VALUEUOM", "FLAG"],
219
+ "code": ["ITEMID"],
220
+ "value": ["VALUENUM"],
221
+ "itemid": [50885],
222
+ },
223
+ "platelets": {
224
+ "fname": "LABEVENTS" + self.ext,
225
+ "timestamp": "CHARTTIME",
226
+ "timeoffsetunit": "abs",
227
+ "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "VALUEUOM", "FLAG"],
228
+ "code": ["ITEMID"],
229
+ "value": ["VALUENUM"],
230
+ "itemid": [51265],
231
+ },
232
+ "wbc": {
233
+ "fname": "LABEVENTS" + self.ext,
234
+ "timestamp": "CHARTTIME",
235
+ "timeoffsetunit": "abs",
236
+ "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "FLAG"],
237
+ "code": ["ITEMID"],
238
+ "value": ["VALUENUM"],
239
+ "itemid": [51300, 51301],
240
+ },
241
+ "dialysis": {
242
+ "tables": {
243
+ "chartevents": {
244
+ "fname": "CHARTEVENTS" + self.ext,
245
+ "timestamp": "CHARTTIME",
246
+ "timeoffsetunit": "abs",
247
+ "include": [
248
+ "ICUSTAY_ID",
249
+ "SUBJECT_ID",
250
+ "ITEMID",
251
+ "VALUE",
252
+ "VALUENUM",
253
+ "CHARTTIME",
254
+ "ERROR",
255
+ ],
256
+ "itemid": {
257
+ "cv_ce": [
258
+ 152,
259
+ 148,
260
+ 149,
261
+ 146,
262
+ 147,
263
+ 151,
264
+ 150,
265
+ 7949,
266
+ 229,
267
+ 235,
268
+ 241,
269
+ 247,
270
+ 253,
271
+ 259,
272
+ 265,
273
+ 271,
274
+ 582,
275
+ 466,
276
+ 917,
277
+ 927,
278
+ 6250,
279
+ ],
280
+ "mv_ce": [
281
+ 226118,
282
+ 227357,
283
+ 225725,
284
+ 226499,
285
+ 224154,
286
+ 225810,
287
+ 227639,
288
+ 225183,
289
+ 227438,
290
+ 224191,
291
+ 225806,
292
+ 225807,
293
+ 228004,
294
+ 228005,
295
+ 228006,
296
+ 224144,
297
+ 224145,
298
+ 224149,
299
+ 224150,
300
+ 224151,
301
+ 224152,
302
+ 224153,
303
+ 224404,
304
+ 224406,
305
+ 226457,
306
+ 225959,
307
+ 224135,
308
+ 224139,
309
+ 224146,
310
+ 225323,
311
+ 225740,
312
+ 225776,
313
+ 225951,
314
+ 225952,
315
+ 225953,
316
+ 225954,
317
+ 225956,
318
+ 225958,
319
+ 225961,
320
+ 225963,
321
+ 225965,
322
+ 225976,
323
+ 225977,
324
+ 227124,
325
+ 227290,
326
+ 227638,
327
+ 227640,
328
+ 227753,
329
+ ],
330
+ },
331
+ },
332
+ "inputevents_cv": {
333
+ "fname": "INPUTEVENTS_CV" + self.ext,
334
+ "timestamp": "CHARTTIME",
335
+ "timeoffsetunit": "abs",
336
+ "include": ["SUBJECT_ID", "ITEMID", "AMOUNT", "CHARTTIME"],
337
+ "itemid": {
338
+ "cv_ie": [
339
+ 40788,
340
+ 40907,
341
+ 41063,
342
+ 41147,
343
+ 41307,
344
+ 41460,
345
+ 41620,
346
+ 41711,
347
+ 41791,
348
+ 41792,
349
+ 42562,
350
+ 43829,
351
+ 44037,
352
+ 44188,
353
+ 44526,
354
+ 44527,
355
+ 44584,
356
+ 44591,
357
+ 44698,
358
+ 44927,
359
+ 44954,
360
+ 45157,
361
+ 45268,
362
+ 45352,
363
+ 45353,
364
+ 46012,
365
+ 46013,
366
+ 46172,
367
+ 46173,
368
+ 46250,
369
+ 46262,
370
+ 46292,
371
+ 46293,
372
+ 46311,
373
+ 46389,
374
+ 46574,
375
+ 46681,
376
+ 46720,
377
+ 46769,
378
+ 46773,
379
+ ]
380
+ },
381
+ },
382
+ "outputevents": {
383
+ "fname": "OUTPUTEVENTS" + self.ext,
384
+ "timestamp": "CHARTTIME",
385
+ "timeoffsetunit": "abs",
386
+ "include": ["SUBJECT_ID", "ITEMID", "VALUE", "CHARTTIME"],
387
+ "itemid": {
388
+ "cv_oe": [
389
+ 40386,
390
+ 40425,
391
+ 40426,
392
+ 40507,
393
+ 40613,
394
+ 40624,
395
+ 40690,
396
+ 40745,
397
+ 40789,
398
+ 40881,
399
+ 40910,
400
+ 41016,
401
+ 41034,
402
+ 41069,
403
+ 41112,
404
+ 41250,
405
+ 41374,
406
+ 41417,
407
+ 41500,
408
+ 41527,
409
+ 41623,
410
+ 41635,
411
+ 41713,
412
+ 41750,
413
+ 41829,
414
+ 41842,
415
+ 41897,
416
+ 42289,
417
+ 42388,
418
+ 42464,
419
+ 42524,
420
+ 42536,
421
+ 42868,
422
+ 42928,
423
+ 42972,
424
+ 43016,
425
+ 43052,
426
+ 43098,
427
+ 43115,
428
+ 43687,
429
+ 43941,
430
+ 44027,
431
+ 44085,
432
+ 44193,
433
+ 44199,
434
+ 44216,
435
+ 44286,
436
+ 44567,
437
+ 44843,
438
+ 44845,
439
+ 44857,
440
+ 44901,
441
+ 44943,
442
+ 45479,
443
+ 45828,
444
+ 46230,
445
+ 46232,
446
+ 46394,
447
+ 46464,
448
+ 46712,
449
+ 46713,
450
+ 46715,
451
+ 46741,
452
+ ],
453
+ },
454
+ },
455
+ "inputevents_mv": {
456
+ "fname": "INPUTEVENTS_MV" + self.ext,
457
+ "timestamp": "STARTTIME",
458
+ "timeoffsetunit": "abs",
459
+ "include": ["SUBJECT_ID", "ITEMID", "AMOUNT", "STARTTIME"],
460
+ "itemid": {"mv_ie": [227536, 227525]},
461
+ },
462
+ "datetimeevents": {
463
+ "fname": "DATETIMEEVENTS" + self.ext,
464
+ "timestamp": "CHARTTIME",
465
+ "timeoffsetunit": "abs",
466
+ "include": ["SUBJECT_ID", "ITEMID", "CHARTTIME"],
467
+ "itemid": {"mv_de": [225318, 225319, 225321, 225322, 225324]},
468
+ },
469
+ "procedureevents_mv": {
470
+ "fname": "PROCEDUREEVENTS_MV" + self.ext,
471
+ "timestamp": "STARTTIME",
472
+ "timeoffsetunit": "abs",
473
+ "include": ["SUBJECT_ID", "ITEMID", "STARTTIME"],
474
+ "itemid": {
475
+ "mv_pe": [225441, 225802, 225803, 225805, 224270, 225809, 225955, 225436]
476
+ },
477
+ },
478
+ }
479
+ },
480
+ }
481
+
482
+ self.disch_map_dict = {
483
+ "DISC-TRAN CANCER/CHLDRN H": "Other",
484
+ "DISC-TRAN TO FEDERAL HC": "Other",
485
+ "DISCH-TRAN TO PSYCH HOSP": "Other",
486
+ "HOME": "Home",
487
+ "HOME HEALTH CARE": "Home",
488
+ "HOME WITH HOME IV PROVIDR": "Home",
489
+ "HOSPICE-HOME": "Other",
490
+ "HOSPICE-MEDICAL FACILITY": "Other",
491
+ "ICF": "Other",
492
+ "IN_ICU_MORTALITY": "IN_ICU_MORTALITY",
493
+ "LEFT AGAINST MEDICAL ADVI": "Other",
494
+ "LONG TERM CARE HOSPITAL": "Other",
495
+ "OTHER FACILITY": "Other",
496
+ "REHAB/DISTINCT PART HOSP": "Rehabilitation",
497
+ "SHORT TERM HOSPITAL": "Other",
498
+ "SNF": "Skilled Nursing Facility",
499
+ "SNF-MEDICAID ONLY CERTIF": "Skilled Nursing Facility",
500
+ "Death": "Death",
501
+ }
502
+ self._icustay_key = "ICUSTAY_ID"
503
+ self._hadm_key = "HADM_ID"
504
+ self._patient_key = "SUBJECT_ID"
505
+
506
+ self._determine_first_icu = "INTIME"
507
+
508
+ def build_cohorts(self, cached=False):
509
+ icustays = pd.read_csv(os.path.join(self.data_dir, self.icustay_fname))
510
+
511
+ icustays = self.make_compatible(icustays)
512
+ self.icustays = icustays
513
+
514
+ cohorts = super().build_cohorts(icustays, cached=cached)
515
+
516
+ return cohorts
517
+
518
+ def prepare_tasks(self, cohorts, spark, cached=False):
519
+ if cohorts is None and cached:
520
+ labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
521
+ if labeled_cohorts is not None:
522
+ return labeled_cohorts
523
+
524
+ labeled_cohorts = super().prepare_tasks(cohorts, spark, cached)
525
+
526
+ if self.diagnosis:
527
+ logger.info("Start labeling cohorts for diagnosis prediction.")
528
+ # define diagnosis prediction task
529
+ diagnoses = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))
530
+
531
+ ccs_dx = pd.read_csv(self.ccs_path)
532
+ ccs_dx["'ICD-9-CM CODE'"] = ccs_dx["'ICD-9-CM CODE'"].str[1:-1].str.strip()
533
+ ccs_dx["'CCS LVL 1'"] = ccs_dx["'CCS LVL 1'"].str[1:-1]
534
+ lvl1 = {x: int(y) - 1 for _, (x, y) in ccs_dx[["'ICD-9-CM CODE'", "'CCS LVL 1'"]].iterrows()}
535
+ diagnoses["diagnosis"] = diagnoses["ICD9_CODE"].map(lvl1)
536
+
537
+ diagnoses = diagnoses[(diagnoses["diagnosis"].notnull()) & (diagnoses["diagnosis"] != 14)]
538
+ diagnoses.loc[diagnoses["diagnosis"] >= 14, "diagnosis"] -= 1
539
+ diagnoses = diagnoses.groupby(self.hadm_key)["diagnosis"].agg(lambda x: list(set(x))).to_frame()
540
+ labeled_cohorts = labeled_cohorts.merge(diagnoses, on=self.hadm_key, how="inner")
541
+
542
+ # labeled_cohorts = labeled_cohorts.drop(columns=["ICD9_CODE"])
543
+
544
+ logger.info(
545
+ "Done preparing diagnosis prediction for the given cohorts, Cohort Numbers: {}".format(
546
+ len(labeled_cohorts)
547
+ )
548
+ )
549
+
550
+ self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")
551
+
552
+ if self.bilirubin or self.platelets or self.creatinine or self.wbc:
553
+ logger.info("Start labeling cohorts for clinical task prediction.")
554
+
555
+ labeled_cohorts = spark.createDataFrame(labeled_cohorts)
556
+
557
+ if self.bilirubin:
558
+ labeled_cohorts = self.clinical_task(labeled_cohorts, "bilirubin", spark)
559
+
560
+ if self.platelets:
561
+ labeled_cohorts = self.clinical_task(labeled_cohorts, "platelets", spark)
562
+
563
+ if self.creatinine:
564
+ labeled_cohorts = self.clinical_task(labeled_cohorts, "creatinine", spark)
565
+
566
+ if self.wbc:
567
+ labeled_cohorts = self.clinical_task(labeled_cohorts, "wbc", spark)
568
+ # labeled_cohorts = labeled_cohorts.toPandas()
569
+
570
+ # self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled.clinical_tasks")
571
+
572
+ logger.info("Done preparing clinical task prediction for the given cohorts")
573
+
574
+ # self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")
575
+ return labeled_cohorts
576
+
577
+ def make_compatible(self, icustays):
578
+ patients = pd.read_csv(os.path.join(self.data_dir, self.patient_fname))
579
+ admissions = pd.read_csv(os.path.join(self.data_dir, self.admission_fname))
580
+
581
+ # prepare icustays according to the appropriate format
582
+ icustays = icustays[icustays["FIRST_CAREUNIT"] == icustays["LAST_CAREUNIT"]]
583
+
584
+ icustays["INTIME"] = pd.to_datetime(icustays["INTIME"], infer_datetime_format=True, utc=True)
585
+
586
+ icustays["OUTTIME"] = pd.to_datetime(icustays["OUTTIME"], infer_datetime_format=True, utc=True)
587
+ icustays = icustays.drop(columns=["ROW_ID"])
588
+
589
+ # merge icustays with patients to get DOB
590
+ patients["DOB"] = pd.to_datetime(patients["DOB"], infer_datetime_format=True, utc=True)
591
+ patients = patients[patients["SUBJECT_ID"].isin(icustays["SUBJECT_ID"])]
592
+ patients = patients.drop(columns=["ROW_ID"])[["DOB", "SUBJECT_ID"]]
593
+ icustays = icustays.merge(patients, on="SUBJECT_ID", how="left")
594
+
595
+ def calculate_age(birth: datetime, now: datetime):
596
+ age = now.year - birth.year
597
+ if now.month < birth.month:
598
+ age -= 1
599
+ elif (now.month == birth.month) and (now.day < birth.day):
600
+ age -= 1
601
+
602
+ return age
603
+
604
+ icustays["AGE"] = icustays.apply(lambda x: calculate_age(x["DOB"], x["INTIME"]), axis=1)
605
+
606
+ # merge with admissions to get discharge information
607
+ icustays = pd.merge(
608
+ icustays.reset_index(drop=True),
609
+ admissions[["HADM_ID", "DISCHARGE_LOCATION", "DEATHTIME", "DISCHTIME"]],
610
+ how="left",
611
+ on="HADM_ID",
612
+ )
613
+ icustays["DISCHARGE_LOCATION"].replace("DEAD/EXPIRED", "Death", inplace=True)
614
+
615
+ icustays["DISCHTIME"] = pd.to_datetime(icustays["DISCHTIME"], infer_datetime_format=True, utc=True)
616
+
617
+ icustays["IN_ICU_MORTALITY"] = (
618
+ (icustays["INTIME"] < icustays["DISCHTIME"])
619
+ & (icustays["DISCHTIME"] <= icustays["OUTTIME"])
620
+ & (icustays["DISCHARGE_LOCATION"] == "Death")
621
+ )
622
+ icustays["DISCHARGE_LOCATION"] = icustays["DISCHARGE_LOCATION"].map(self.disch_map_dict)
623
+
624
+ icustays.rename(columns={"DISCHARGE_LOCATION": "HOS_DISCHARGE_LOCATION"}, inplace=True)
625
+ icustays["DISCHTIME"] = (icustays["DISCHTIME"] - icustays["INTIME"]).dt.total_seconds() // 60
626
+ icustays["OUTTIME"] = (icustays["OUTTIME"] - icustays["INTIME"]).dt.total_seconds() // 60
627
+ return icustays
628
+
629
    def clinical_task(self, cohorts, task, spark):
        """Label `cohorts` with a discretized lab-value target for `task`.

        Reads the lab table configured in `self.task_itemids[task]`, restricts it
        to events inside the (obs_size + gap_size, obs_size + pred_size) window
        relative to ICU INTIME, averages the values per ICU stay, and bins the
        average into the task's severity classes. For "creatinine", stays of
        patients with any dialysis event before the lab event are excluded first.

        Args:
            cohorts: Spark DataFrame of labeled cohorts (keyed by hadm/icustay/patient).
            task: one of "bilirubin", "platelets", "creatinine", "wbc".
            spark: active SparkSession.

        Returns:
            `cohorts` left-joined with a new `task` label column (null when no
            qualifying events exist for a stay).
        """
        fname = self.task_itemids[task]["fname"]
        timestamp = self.task_itemids[task]["timestamp"]
        timeoffsetunit = self.task_itemids[task]["timeoffsetunit"]
        excludes = self.task_itemids[task]["exclude"]
        code = self.task_itemids[task]["code"][0]
        value = self.task_itemids[task]["value"][0]
        itemid = self.task_itemids[task]["itemid"]

        # Keep only rows for the task's item IDs that carry a numeric value.
        table = spark.read.csv(os.path.join(self.data_dir, fname), header=True)
        table = table.drop(*excludes)
        table = table.filter(F.col(code).isin(itemid)).filter(F.col(value).isNotNull())

        merge = cohorts.join(table, on=self.hadm_key, how="inner")

        if timeoffsetunit == "abs":
            merge = merge.withColumn(timestamp, F.to_timestamp(timestamp))

        # For Creatinine task, eliminate icus if patient went through dialysis treatment
        # before (obs_size + pred_size) timestamp
        # see https://github.com/MIT-LCP/mimic-code/blob/main/mimic-iii/concepts/rrt.sql
        if task == "creatinine":
            dialysis_tables = self.task_itemids["dialysis"]["tables"]
            chartevents = spark.read.csv(os.path.join(self.data_dir, "CHARTEVENTS" + self.ext), header=True)
            inputevents_cv = spark.read.csv(
                os.path.join(self.data_dir, "INPUTEVENTS_CV" + self.ext), header=True
            )
            outputevents = spark.read.csv(os.path.join(self.data_dir, "OUTPUTEVENTS" + self.ext), header=True)
            inputevents_mv = spark.read.csv(
                os.path.join(self.data_dir, "INPUTEVENTS_MV" + self.ext), header=True
            )
            datetimeevents = spark.read.csv(
                os.path.join(self.data_dir, "DATETIMEEVENTS" + self.ext), header=True
            )
            procedureevents_mv = spark.read.csv(
                os.path.join(self.data_dir, "PROCEDUREEVENTS_MV" + self.ext), header=True
            )
            icustays = spark.read.csv(os.path.join(self.data_dir, "ICUSTAYS" + self.ext), header=True)

            # Project each source down to the columns the dialysis check needs.
            chartevents = chartevents.select(*dialysis_tables["chartevents"]["include"])
            inputevents_cv = inputevents_cv.select(*dialysis_tables["inputevents_cv"]["include"])
            outputevents = outputevents.select(*dialysis_tables["outputevents"]["include"])
            inputevents_mv = inputevents_mv.select(*dialysis_tables["inputevents_mv"]["include"])
            datetimeevents = datetimeevents.select(*dialysis_tables["datetimeevents"]["include"])
            procedureevents_mv = procedureevents_mv.select(*dialysis_tables["procedureevents_mv"]["include"])

            # Filter dialysis related tables with dialysis condition #TODO: check dialysis condition
            # CareVue chart events: per-itemid VALUE conditions adapted from mimic-code's rrt.sql.
            cv_ce = (
                chartevents.filter(F.col("ITEMID").isin(dialysis_tables["chartevents"]["itemid"]["cv_ce"]))
                .filter(F.col("VALUE").isNotNull())
                .filter((F.col("ERROR").isNull()) | (F.col("ERROR") == 0))
                .filter(
                    (
                        (F.col("ITEMID").isin([152, 148, 149, 146, 147, 151, 150]))
                        & (F.col("VALUE").isNotNull())
                    )
                    | (
                        (F.col("ITEMID").isin([229, 235, 241, 247, 253, 259, 265, 271]))
                        & (F.col("VALUE") == "Dialysis Line")
                    )
                    | ((F.col("ITEMID") == 466) & (F.col("VALUE") == "Dialysis RN"))
                    | ((F.col("ITEMID") == 927) & (F.col("VALUE") == "Dialysis Solutions"))
                    | ((F.col("ITEMID") == 6250) & (F.col("VALUE") == "dialys"))
                    | (
                        (F.col("ITEMID") == 917)
                        & (
                            F.col("VALUE").isin(
                                [
                                    "+ INITIATE DIALYSIS",
                                    "BLEEDING FROM DIALYSIS CATHETER",
                                    "FAILED DIALYSIS CATH.",
                                    "FEBRILE SYNDROME;DIALYSIS",
                                    "HYPOTENSION WITH HEMODIALYSIS",
                                    "HYPOTENSION.GLOGGED DIALYSIS",
                                    "INFECTED DIALYSIS CATHETER",
                                ]
                            )
                        )
                    )
                    | (
                        (F.col("ITEMID") == 582)
                        & (
                            F.col("VALUE").isin(
                                [
                                    "CAVH Start",
                                    "CAVH D/C",
                                    "CVVHD Start",
                                    "CVVHD D/C",
                                    "Hemodialysis st",
                                    "Hemodialysis end",
                                ]
                            )
                        )
                    )
                )
            )
            # Restrict CareVue chart events to stays actually recorded in CareVue.
            icustays = icustays.filter(F.col("DBSOURCE") == "carevue").select(self.icustay_key)
            cv_ce = cv_ce.join(icustays, on=self.icustay_key, how="inner")
            cv_ie = inputevents_cv.filter(
                F.col("ITEMID").isin(dialysis_tables["inputevents_cv"]["itemid"]["cv_ie"])
            ).filter(F.col("AMOUNT") > 0)
            cv_oe = outputevents.filter(
                F.col("ITEMID").isin(dialysis_tables["outputevents"]["itemid"]["cv_oe"])
            ).filter(F.col("VALUE") > 0)
            mv_ce = (
                chartevents.filter(F.col("ITEMID").isin(dialysis_tables["chartevents"]["itemid"]["mv_ce"]))
                .filter(F.col("VALUENUM") > 0)
                .filter((F.col("ERROR").isNull()) | (F.col("ERROR") == 0))
            )
            mv_ie = inputevents_mv.filter(
                F.col("ITEMID").isin(dialysis_tables["inputevents_mv"]["itemid"]["mv_ie"])
            ).filter(F.col("AMOUNT") > 0)
            mv_de = datetimeevents.filter(
                F.col("ITEMID").isin(dialysis_tables["datetimeevents"]["itemid"]["mv_de"])
            )
            mv_pe = procedureevents_mv.filter(
                F.col("ITEMID").isin(dialysis_tables["procedureevents_mv"]["itemid"]["mv_pe"])
            )

            def dialysis_time(table, timecolumn):
                # Normalize each source to (patient_key, _DIALYSIS_TIME) for the union below.
                return table.withColumn("_DIALYSIS_TIME", F.to_timestamp(timecolumn)).select(
                    self.patient_key, "_DIALYSIS_TIME"
                )

            cv_ce, cv_ie, cv_oe, mv_ce, mv_ie, mv_de, mv_pe = (
                dialysis_time(cv_ce, "CHARTTIME"),
                dialysis_time(cv_ie, "CHARTTIME"),
                dialysis_time(cv_oe, "CHARTTIME"),
                dialysis_time(mv_ce, "CHARTTIME"),
                dialysis_time(mv_ie, "STARTTIME"),
                dialysis_time(mv_de, "CHARTTIME"),
                dialysis_time(mv_pe, "STARTTIME"),
            )

            # Earliest dialysis event per patient; keep lab events strictly before it.
            dialysis = cv_ce.union(cv_ie).union(cv_oe).union(mv_ce).union(mv_ie).union(mv_de).union(mv_pe)
            dialysis = dialysis.groupBy(self.patient_key).agg(F.min("_DIALYSIS_TIME").alias("_DIALYSIS_TIME"))
            merge = merge.join(dialysis, on=self.patient_key, how="left")
            merge = merge.filter(F.isnull("_DIALYSIS_TIME") | (F.col("_DIALYSIS_TIME") > F.col(timestamp)))
            merge = merge.drop("_DIALYSIS_TIME")

        if timeoffsetunit == "abs":
            # Convert absolute event time to minutes since ICU INTIME.
            merge = merge.withColumn(
                timestamp, F.round((F.col(timestamp).cast("long") - F.col("INTIME").cast("long")) / 60)
            )

        # Cohort with events within (obs_size + gap_size) - (obs_size + pred_size)
        merge = merge.filter(((self.obs_size + self.gap_size) * 60) <= F.col(timestamp)).filter(
            ((self.obs_size + self.pred_size) * 60) >= F.col(timestamp)
        )

        # Average value of events
        value_agg = merge.groupBy(self.icustay_key).agg(
            F.mean(value).alias("avg_value")
        )  # TODO: mean/min/max?

        # Labeling: bin the per-stay average into severity classes.
        if task == "bilirubin":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value < 1.2, 0)
                .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
                .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 6.0), 2)
                .when((value_agg.avg_value >= 6.0) & (value_agg.avg_value < 12.0), 3)
                .when(value_agg.avg_value >= 12.0, 4),
            )
        elif task == "platelets":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value >= 150, 0)
                .when((value_agg.avg_value >= 100) & (value_agg.avg_value < 150), 1)
                .when((value_agg.avg_value >= 50) & (value_agg.avg_value < 100), 2)
                .when((value_agg.avg_value >= 20) & (value_agg.avg_value < 50), 3)
                .when(value_agg.avg_value < 20, 4),
            )

        elif task == "creatinine":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value < 1.2, 0)
                .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
                .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 3.5), 2)
                .when((value_agg.avg_value >= 3.5) & (value_agg.avg_value < 5), 3)
                .when(value_agg.avg_value >= 5, 4),
            )

        elif task == "wbc":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value < 4, 0)
                .when((value_agg.avg_value >= 4) & (value_agg.avg_value <= 12), 1)
                .when((value_agg.avg_value > 12), 2),
            )

        cohorts = cohorts.join(value_agg.select(self.icustay_key, task), on=self.icustay_key, how="left")

        return cohorts
+ def infer_data_extension(self) -> str:
827
+ if len(glob.glob(os.path.join(self.data_dir, "*.csv.gz"))) == 26:
828
+ ext = ".csv.gz"
829
+ elif len(glob.glob(os.path.join(self.data_dir, "*.csv"))) == 26:
830
+ ext = ".csv"
831
+ else:
832
+ raise AssertionError(
833
+ "Provided data directory is not correct. Please check if --data is correct. "
834
+ "--data: {}".format(self.data_dir)
835
+ )
836
+
837
+ logger.info("Data extension is set to '{}'".format(ext))
838
+
839
+ return ext