genhpf 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of genhpf has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (67) hide show
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +233 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +174 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +584 -0
  53. genhpf/scripts/test.py +261 -0
  54. genhpf/scripts/train.py +350 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.0.dist-info/LICENSE +21 -0
  63. genhpf-1.0.0.dist-info/METADATA +197 -0
  64. genhpf-1.0.0.dist-info/RECORD +67 -0
  65. genhpf-1.0.0.dist-info/WHEEL +5 -0
  66. genhpf-1.0.0.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,839 @@
1
+ import glob
2
+ import logging
3
+ import os
4
+ from datetime import datetime
5
+
6
+ import pandas as pd
7
+ import pyspark.sql.functions as F
8
+ from ehrs import EHR, register_ehr
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
@register_ehr("mimiciii")
class MIMICIII(EHR):
    """Preprocessing pipeline for the MIMIC-III v1.4 EHR dataset.

    Resolves data/CCS paths (downloading from PhysioNet when missing), declares
    which event tables/columns are used, and defines item-id constants for the
    clinical prediction tasks (creatinine, bilirubin, platelets, wbc) plus the
    dialysis-exclusion lookup used by the creatinine task.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        self.ehr_name = "mimiciii"

        if self.data_dir is None:
            self.data_dir = os.path.join(self.cache_dir, self.ehr_name)

            if not os.path.exists(self.data_dir):
                logger.info(
                    "Data is not found so try to download from the internet. "
                    "Note that this is a restricted-access resource. "
                    "Please log in to physionet.org with a credentialed user."
                )
                self.download_ehr_from_url(
                    url="https://physionet.org/files/mimiciii/1.4/", dest=self.data_dir
                )

        logger.info("Data directory is set to {}".format(self.data_dir))

        if self.ccs_path is None:
            self.ccs_path = os.path.join(self.cache_dir, "ccs_multi_dx_tool_2015.csv")

            if not os.path.exists(self.ccs_path):
                logger.info("`ccs_multi_dx_tool_2015.csv` is not found so try to download from the internet.")
                self.download_ccs_from_url(self.cache_dir)

        if self.ext is None:
            self.ext = self.infer_data_extension()

        # constants: canonical MIMIC-III table filenames
        self._icustay_fname = "ICUSTAYS" + self.ext
        self._patient_fname = "PATIENTS" + self.ext
        self._admission_fname = "ADMISSIONS" + self.ext
        self._diagnosis_fname = "DIAGNOSES_ICD" + self.ext

        # XXX more features? user choice?
        # Event tables that feed the model; "exclude" columns are dropped,
        # "code"/"desc"/"desc_key" drive item-id -> human-readable label lookup.
        self.tables = [
            {
                "fname": "LABEVENTS" + self.ext,
                "timestamp": "CHARTTIME",
                "timeoffsetunit": "abs",
                "exclude": ["ROW_ID", "SUBJECT_ID"],
                "code": ["ITEMID"],
                "desc": ["D_LABITEMS" + self.ext],
                "desc_key": ["LABEL"],
            },
            {
                "fname": "PRESCRIPTIONS" + self.ext,
                "timestamp": "STARTDATE",
                "timeoffsetunit": "abs",
                "exclude": ["ENDDATE", "GSN", "NDC", "ROW_ID", "SUBJECT_ID"],
            },
            {
                "fname": "INPUTEVENTS_MV" + self.ext,
                "timestamp": "STARTTIME",
                "timeoffsetunit": "abs",
                "exclude": [
                    "ENDTIME",
                    "STORETIME",
                    "CGID",
                    "ORDERID",
                    "LINKORDERID",
                    "ROW_ID",
                    "SUBJECT_ID",
                    "CONTINUEINNEXTDEPT",
                    "CANCELREASON",
                    "STATUSDESCRIPTION",
                    "COMMENTS_CANCELEDBY",
                    "COMMENTS_DATE",
                ],
                "code": ["ITEMID"],
                "desc": ["D_ITEMS" + self.ext],
                "desc_key": ["LABEL"],
            },
            {
                "fname": "INPUTEVENTS_CV" + self.ext,
                "timestamp": "CHARTTIME",
                "timeoffsetunit": "abs",
                "exclude": [
                    "STORETIME",
                    "CGID",
                    "ORDERID",
                    "LINKORDERID",
                    "ROW_ID",
                    "STOPPED",
                    "SUBJECT_ID",
                ],
                "code": ["ITEMID"],
                "desc": ["D_ITEMS" + self.ext],
                "desc_key": ["LABEL"],
            },
        ]

        if self.feature == "select":
            # Extra columns dropped when only a curated feature subset is used.
            extra_exclude_feature_dict = {
                "LABEVENTS" + self.ext: ["VALUE", "FLAG"],
                "PRESCRIPTIONS"
                + self.ext: [
                    "DRUG_TYPE",
                    "DRUG_NAME_POE",
                    "DRUG_NAME_GENERIC",
                    "FORMULARY_DRUG_CD",
                    "FORM_VAL_DISP",
                    "FORM_UNIT_DISP",
                ],
                "INPUTEVENTS_MV"
                + self.ext: [
                    "AMOUNT",
                    "AMOUNTUOM",
                    "ORDERCATEGORYNAME",
                    "SECONDARYORDERCATEGORYNAME",
                    "ORDERCOMPONENTTYPEDESCRIPTION",
                    "ORDERCATEGORYDESCRIPTION",
                    "PATIENTWEIGHT",
                    "TOTALAMOUNT",
                    "TOTALAMOUNTUOM",
                    "ISOPENBAG",
                    "COMMENTS_EDITEDBY",
                    "ORIGINALAMOUNT",
                    "ORIGINALRATE",
                ],
                # FIX: the original list misspelled "ORIGINALAMOUNTUOM" as
                # "ORIGIANLAMOUNTUOM" (so that column was never excluded) and
                # listed "ORIGINALRATE" twice. Corrected and deduplicated here.
                "INPUTEVENTS_CV"
                + self.ext: [
                    "AMOUNT",
                    "AMOUNTUOM",
                    "NEWBOTTLE",
                    "ORIGINALAMOUNT",
                    "ORIGINALRATE",
                    "ORIGINALAMOUNTUOM",
                    "ORIGINALRATEUOM",
                    "ORIGINALSITE",
                ],
            }

            for table in self.tables:
                if table["fname"] in extra_exclude_feature_dict.keys():
                    exclude_target_list = extra_exclude_feature_dict[table["fname"]]
                    table["exclude"].extend(exclude_target_list)

        if self.emb_type == "codebase":
            # Per-table feature typing used by the codebase embedding scheme.
            feature_types_for_codebase_emb_dict = {
                "LABEVENTS"
                + self.ext: {
                    "numeric_feat": ["VALUE", "VALUENUM"],
                    "categorical_feat": [],
                    "code_feat": ["ITEMID"],
                },
                "PRESCRIPTIONS"
                + self.ext: {
                    "numeric_feat": ["DOSE_VAL_RX", "FORM_VAL_DISP"],
                    "categorical_feat": [],
                    "code_feat": ["DRUG"],
                },
                "INPUTEVENTS_MV"
                + self.ext: {
                    "numeric_feat": [
                        "AMOUNT",
                        "RATE",
                        "PATIENTWEIGHT",
                        "TOTALAMOUNT",
                        "ORIGINALAMOUNT",
                        "ORIGINALRATE",
                    ],
                    "categorical_feat": ["ISOPENBAG", "CONTINUEINNEXTDEPT", "CANCELREASON", "VALUECOUNTS"],
                    "code_feat": ["ITEMID"],
                },
                "INPUTEVENTS_CV"
                + self.ext: {
                    "numeric_feat": [
                        "AMOUNT",
                        "RATE",
                        "PATIENTWEIGHT",
                        "TOTALAMOUNT",
                        "ORIGINALAMOUNT",
                        "ORIGINALRATE",
                    ],
                    "categorical_feat": ["ISOPENBAG", "CONTINUEINNEXTDEPT", "CANCELREASON", "VALUECOUNTS"],
                    "code_feat": ["ITEMID"],
                },
            }

            for table in self.tables:
                if table["fname"] in feature_types_for_codebase_emb_dict.keys():
                    feature_dict = feature_types_for_codebase_emb_dict[table["fname"]]
                    table.update(feature_dict)

        if self.creatinine or self.bilirubin or self.platelets or self.wbc:
            # LABEVENTS item ids per task; "dialysis" lists events used to
            # exclude dialysis patients from the creatinine task.
            self.task_itemids = {
                "creatinine": {
                    "fname": "LABEVENTS" + self.ext,
                    "timestamp": "CHARTTIME",
                    "timeoffsetunit": "abs",
                    "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "VALUEUOM", "FLAG"],
                    "code": ["ITEMID"],
                    "value": ["VALUENUM"],
                    "itemid": [50912],
                },
                "bilirubin": {
                    "fname": "LABEVENTS" + self.ext,
                    "timestamp": "CHARTTIME",
                    "timeoffsetunit": "abs",
                    "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "VALUEUOM", "FLAG"],
                    "code": ["ITEMID"],
                    "value": ["VALUENUM"],
                    "itemid": [50885],
                },
                "platelets": {
                    "fname": "LABEVENTS" + self.ext,
                    "timestamp": "CHARTTIME",
                    "timeoffsetunit": "abs",
                    "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "VALUEUOM", "FLAG"],
                    "code": ["ITEMID"],
                    "value": ["VALUENUM"],
                    "itemid": [51265],
                },
                "wbc": {
                    "fname": "LABEVENTS" + self.ext,
                    "timestamp": "CHARTTIME",
                    "timeoffsetunit": "abs",
                    "exclude": ["ROW_ID", "SUBJECT_ID", "VALUE", "FLAG"],
                    "code": ["ITEMID"],
                    "value": ["VALUENUM"],
                    "itemid": [51300, 51301],
                },
                "dialysis": {
                    "tables": {
                        "chartevents": {
                            "fname": "CHARTEVENTS" + self.ext,
                            "timestamp": "CHARTTIME",
                            "timeoffsetunit": "abs",
                            "include": [
                                "ICUSTAY_ID",
                                "SUBJECT_ID",
                                "ITEMID",
                                "VALUE",
                                "VALUENUM",
                                "CHARTTIME",
                                "ERROR",
                            ],
                            # cv_* = CareVue item ids, mv_* = MetaVision item ids
                            "itemid": {
                                "cv_ce": [
                                    152, 148, 149, 146, 147, 151, 150, 7949, 229, 235,
                                    241, 247, 253, 259, 265, 271, 582, 466, 917, 927,
                                    6250,
                                ],
                                "mv_ce": [
                                    226118, 227357, 225725, 226499, 224154, 225810,
                                    227639, 225183, 227438, 224191, 225806, 225807,
                                    228004, 228005, 228006, 224144, 224145, 224149,
                                    224150, 224151, 224152, 224153, 224404, 224406,
                                    226457, 225959, 224135, 224139, 224146, 225323,
                                    225740, 225776, 225951, 225952, 225953, 225954,
                                    225956, 225958, 225961, 225963, 225965, 225976,
                                    225977, 227124, 227290, 227638, 227640, 227753,
                                ],
                            },
                        },
                        "inputevents_cv": {
                            "fname": "INPUTEVENTS_CV" + self.ext,
                            "timestamp": "CHARTTIME",
                            "timeoffsetunit": "abs",
                            "include": ["SUBJECT_ID", "ITEMID", "AMOUNT", "CHARTTIME"],
                            "itemid": {
                                "cv_ie": [
                                    40788, 40907, 41063, 41147, 41307, 41460, 41620,
                                    41711, 41791, 41792, 42562, 43829, 44037, 44188,
                                    44526, 44527, 44584, 44591, 44698, 44927, 44954,
                                    45157, 45268, 45352, 45353, 46012, 46013, 46172,
                                    46173, 46250, 46262, 46292, 46293, 46311, 46389,
                                    46574, 46681, 46720, 46769, 46773,
                                ]
                            },
                        },
                        "outputevents": {
                            "fname": "OUTPUTEVENTS" + self.ext,
                            "timestamp": "CHARTTIME",
                            "timeoffsetunit": "abs",
                            "include": ["SUBJECT_ID", "ITEMID", "VALUE", "CHARTTIME"],
                            "itemid": {
                                "cv_oe": [
                                    40386, 40425, 40426, 40507, 40613, 40624, 40690,
                                    40745, 40789, 40881, 40910, 41016, 41034, 41069,
                                    41112, 41250, 41374, 41417, 41500, 41527, 41623,
                                    41635, 41713, 41750, 41829, 41842, 41897, 42289,
                                    42388, 42464, 42524, 42536, 42868, 42928, 42972,
                                    43016, 43052, 43098, 43115, 43687, 43941, 44027,
                                    44085, 44193, 44199, 44216, 44286, 44567, 44843,
                                    44845, 44857, 44901, 44943, 45479, 45828, 46230,
                                    46232, 46394, 46464, 46712, 46713, 46715, 46741,
                                ],
                            },
                        },
                        "inputevents_mv": {
                            "fname": "INPUTEVENTS_MV" + self.ext,
                            "timestamp": "STARTTIME",
                            "timeoffsetunit": "abs",
                            "include": ["SUBJECT_ID", "ITEMID", "AMOUNT", "STARTTIME"],
                            "itemid": {"mv_ie": [227536, 227525]},
                        },
                        "datetimeevents": {
                            "fname": "DATETIMEEVENTS" + self.ext,
                            "timestamp": "CHARTTIME",
                            "timeoffsetunit": "abs",
                            "include": ["SUBJECT_ID", "ITEMID", "CHARTTIME"],
                            "itemid": {"mv_de": [225318, 225319, 225321, 225322, 225324]},
                        },
                        "procedureevents_mv": {
                            "fname": "PROCEDUREEVENTS_MV" + self.ext,
                            "timestamp": "STARTTIME",
                            "timeoffsetunit": "abs",
                            "include": ["SUBJECT_ID", "ITEMID", "STARTTIME"],
                            "itemid": {
                                "mv_pe": [225441, 225802, 225803, 225805, 224270, 225809, 225955, 225436]
                            },
                        },
                    }
                },
            }

        # Collapse the raw DISCHARGE_LOCATION strings into a small label set.
        self.disch_map_dict = {
            "DISC-TRAN CANCER/CHLDRN H": "Other",
            "DISC-TRAN TO FEDERAL HC": "Other",
            "DISCH-TRAN TO PSYCH HOSP": "Other",
            "HOME": "Home",
            "HOME HEALTH CARE": "Home",
            "HOME WITH HOME IV PROVIDR": "Home",
            "HOSPICE-HOME": "Other",
            "HOSPICE-MEDICAL FACILITY": "Other",
            "ICF": "Other",
            "IN_ICU_MORTALITY": "IN_ICU_MORTALITY",
            "LEFT AGAINST MEDICAL ADVI": "Other",
            "LONG TERM CARE HOSPITAL": "Other",
            "OTHER FACILITY": "Other",
            "REHAB/DISTINCT PART HOSP": "Rehabilitation",
            "SHORT TERM HOSPITAL": "Other",
            "SNF": "Skilled Nursing Facility",
            "SNF-MEDICAID ONLY CERTIF": "Skilled Nursing Facility",
            "Death": "Death",
        }
        # Join keys used throughout the preprocessing pipeline.
        self._icustay_key = "ICUSTAY_ID"
        self._hadm_key = "HADM_ID"
        self._patient_key = "SUBJECT_ID"

        # Column used to pick the first ICU stay per admission.
        self._determine_first_icu = "INTIME"
507
+
508
+ def build_cohorts(self, cached=False):
509
+ icustays = pd.read_csv(os.path.join(self.data_dir, self.icustay_fname))
510
+
511
+ icustays = self.make_compatible(icustays)
512
+ self.icustays = icustays
513
+
514
+ cohorts = super().build_cohorts(icustays, cached=cached)
515
+
516
+ return cohorts
517
+
518
    def prepare_tasks(self, cohorts, spark, cached=False):
        """Attach prediction-task labels to the cohort table.

        Adds the base-class labels, then (optionally) a multi-label diagnosis
        target derived from CCS level-1 categories, and per-stay clinical task
        labels (bilirubin / platelets / creatinine / wbc) computed with Spark.

        Args:
            cohorts: cohort DataFrame from ``build_cohorts`` (may be None when
                loading from cache).
            spark: active Spark session used for the clinical tasks.
            cached: if True and ``cohorts`` is None, reuse a cached result.

        Returns:
            The labeled cohorts; a pandas DataFrame, or a Spark DataFrame when
            any clinical task was requested (note the ``toPandas`` conversion
            is commented out below).
        """
        # Fast path: reuse previously labeled cohorts from cache.
        if cohorts is None and cached:
            labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
            if labeled_cohorts is not None:
                return labeled_cohorts

        # Labels defined by the shared EHR base class.
        labeled_cohorts = super().prepare_tasks(cohorts, spark, cached)

        if self.diagnosis:
            logger.info("Start labeling cohorts for diagnosis prediction.")
            # define diagnosis prediction task
            diagnoses = pd.read_csv(os.path.join(self.data_dir, self.diagnosis_fname))

            # CCS fields are quoted in the csv (e.g. 'CODE'); str[1:-1] strips
            # the surrounding quotes before building the ICD9 -> category map.
            ccs_dx = pd.read_csv(self.ccs_path)
            ccs_dx["'ICD-9-CM CODE'"] = ccs_dx["'ICD-9-CM CODE'"].str[1:-1].str.strip()
            ccs_dx["'CCS LVL 1'"] = ccs_dx["'CCS LVL 1'"].str[1:-1]
            # Map ICD9 code -> 0-based CCS level-1 category index.
            lvl1 = {x: int(y) - 1 for _, (x, y) in ccs_dx[["'ICD-9-CM CODE'", "'CCS LVL 1'"]].iterrows()}
            diagnoses["diagnosis"] = diagnoses["ICD9_CODE"].map(lvl1)

            # Drop unmapped codes and category 14, then close the gap so the
            # remaining category indices stay contiguous.
            diagnoses = diagnoses[(diagnoses["diagnosis"].notnull()) & (diagnoses["diagnosis"] != 14)]
            diagnoses.loc[diagnoses["diagnosis"] >= 14, "diagnosis"] -= 1
            # One multi-label set of categories per hospital admission.
            diagnoses = diagnoses.groupby(self.hadm_key)["diagnosis"].agg(lambda x: list(set(x))).to_frame()
            # Inner join: admissions without any usable diagnosis are dropped.
            labeled_cohorts = labeled_cohorts.merge(diagnoses, on=self.hadm_key, how="inner")

            # labeled_cohorts = labeled_cohorts.drop(columns=["ICD9_CODE"])

            logger.info(
                "Done preparing diagnosis prediction for the given cohorts, Cohort Numbers: {}".format(
                    len(labeled_cohorts)
                )
            )

        self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")

        if self.bilirubin or self.platelets or self.creatinine or self.wbc:
            logger.info("Start labeling cohorts for clinical task prediction.")

            # Clinical tasks run on Spark; convert once, then label in place.
            labeled_cohorts = spark.createDataFrame(labeled_cohorts)

            if self.bilirubin:
                labeled_cohorts = self.clinical_task(labeled_cohorts, "bilirubin", spark)

            if self.platelets:
                labeled_cohorts = self.clinical_task(labeled_cohorts, "platelets", spark)

            if self.creatinine:
                labeled_cohorts = self.clinical_task(labeled_cohorts, "creatinine", spark)

            if self.wbc:
                labeled_cohorts = self.clinical_task(labeled_cohorts, "wbc", spark)
            # labeled_cohorts = labeled_cohorts.toPandas()

            # self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled.clinical_tasks")

            logger.info("Done preparing clinical task prediction for the given cohorts")

        # self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")
        return labeled_cohorts
576
+
577
    def make_compatible(self, icustays):
        """Normalize the raw ICUSTAYS table into the pipeline's common format.

        Joins PATIENTS (for age at ICU admission) and ADMISSIONS (for discharge
        disposition/time), derives ``AGE``, ``IN_ICU_MORTALITY`` and
        ``HOS_DISCHARGE_LOCATION``, and converts DISCHTIME/OUTTIME into minutes
        relative to INTIME.

        Args:
            icustays: raw ICUSTAYS DataFrame as read from csv.

        Returns:
            The normalized ICU-stay DataFrame.
        """
        patients = pd.read_csv(os.path.join(self.data_dir, self.patient_fname))
        admissions = pd.read_csv(os.path.join(self.data_dir, self.admission_fname))

        # prepare icustays according to the appropriate format
        # Keep only stays that did not move between care units.
        icustays = icustays[icustays["FIRST_CAREUNIT"] == icustays["LAST_CAREUNIT"]]

        icustays["INTIME"] = pd.to_datetime(icustays["INTIME"], infer_datetime_format=True, utc=True)

        icustays["OUTTIME"] = pd.to_datetime(icustays["OUTTIME"], infer_datetime_format=True, utc=True)
        icustays = icustays.drop(columns=["ROW_ID"])

        # merge icustays with patients to get DOB
        patients["DOB"] = pd.to_datetime(patients["DOB"], infer_datetime_format=True, utc=True)
        patients = patients[patients["SUBJECT_ID"].isin(icustays["SUBJECT_ID"])]
        patients = patients.drop(columns=["ROW_ID"])[["DOB", "SUBJECT_ID"]]
        icustays = icustays.merge(patients, on="SUBJECT_ID", how="left")

        def calculate_age(birth: datetime, now: datetime) -> int:
            # Whole-year age, decremented if the birthday hasn't occurred yet.
            age = now.year - birth.year
            if now.month < birth.month:
                age -= 1
            elif (now.month == birth.month) and (now.day < birth.day):
                age -= 1

            return age

        # Age at ICU admission (row-wise; DOB may be shifted for de-identified
        # elderly patients in MIMIC-III — TODO confirm downstream handling).
        icustays["AGE"] = icustays.apply(lambda x: calculate_age(x["DOB"], x["INTIME"]), axis=1)

        # merge with admissions to get discharge information
        icustays = pd.merge(
            icustays.reset_index(drop=True),
            admissions[["HADM_ID", "DISCHARGE_LOCATION", "DEATHTIME", "DISCHTIME"]],
            how="left",
            on="HADM_ID",
        )
        icustays["DISCHARGE_LOCATION"].replace("DEAD/EXPIRED", "Death", inplace=True)

        icustays["DISCHTIME"] = pd.to_datetime(icustays["DISCHTIME"], infer_datetime_format=True, utc=True)

        # Died in the ICU: hospital discharge (as Death) fell within the stay.
        icustays["IN_ICU_MORTALITY"] = (
            (icustays["INTIME"] < icustays["DISCHTIME"])
            & (icustays["DISCHTIME"] <= icustays["OUTTIME"])
            & (icustays["DISCHARGE_LOCATION"] == "Death")
        )
        # Collapse raw discharge strings via the mapping built in __init__.
        icustays["DISCHARGE_LOCATION"] = icustays["DISCHARGE_LOCATION"].map(self.disch_map_dict)

        icustays.rename(columns={"DISCHARGE_LOCATION": "HOS_DISCHARGE_LOCATION"}, inplace=True)
        # Convert absolute times to minutes since ICU admission.
        icustays["DISCHTIME"] = (icustays["DISCHTIME"] - icustays["INTIME"]).dt.total_seconds() // 60
        icustays["OUTTIME"] = (icustays["OUTTIME"] - icustays["INTIME"]).dt.total_seconds() // 60
        return icustays
628
+
629
    def clinical_task(self, cohorts, task, spark):
        """Label each ICU stay for one clinical prediction task.

        Filters LABEVENTS to the task's item ids, restricts events to the
        (obs_size + gap_size, obs_size + pred_size] window after ICU admission,
        averages the measured values per stay, and bins the average into the
        task's ordinal classes. For creatinine, stays of patients with any
        dialysis-related event at or before the measurement time are excluded.

        Args:
            cohorts: Spark DataFrame of labeled cohorts.
            task: one of "creatinine", "bilirubin", "platelets", "wbc".
            spark: active Spark session.

        Returns:
            ``cohorts`` with a new ``task`` column (null where no label).
        """
        fname = self.task_itemids[task]["fname"]
        timestamp = self.task_itemids[task]["timestamp"]
        timeoffsetunit = self.task_itemids[task]["timeoffsetunit"]
        excludes = self.task_itemids[task]["exclude"]
        code = self.task_itemids[task]["code"][0]
        value = self.task_itemids[task]["value"][0]
        itemid = self.task_itemids[task]["itemid"]

        # Task-relevant lab events with non-null numeric values only.
        table = spark.read.csv(os.path.join(self.data_dir, fname), header=True)
        table = table.drop(*excludes)
        table = table.filter(F.col(code).isin(itemid)).filter(F.col(value).isNotNull())

        merge = cohorts.join(table, on=self.hadm_key, how="inner")

        if timeoffsetunit == "abs":
            merge = merge.withColumn(timestamp, F.to_timestamp(timestamp))

        # For Creatinine task, eliminate icus if patient went through dialysis treatment
        # before (obs_size + pred_size) timestamp
        # see https://github.com/MIT-LCP/mimic-code/blob/main/mimic-iii/concepts/rrt.sql
        if task == "creatinine":
            dialysis_tables = self.task_itemids["dialysis"]["tables"]
            chartevents = spark.read.csv(os.path.join(self.data_dir, "CHARTEVENTS" + self.ext), header=True)
            inputevents_cv = spark.read.csv(
                os.path.join(self.data_dir, "INPUTEVENTS_CV" + self.ext), header=True
            )
            outputevents = spark.read.csv(os.path.join(self.data_dir, "OUTPUTEVENTS" + self.ext), header=True)
            inputevents_mv = spark.read.csv(
                os.path.join(self.data_dir, "INPUTEVENTS_MV" + self.ext), header=True
            )
            datetimeevents = spark.read.csv(
                os.path.join(self.data_dir, "DATETIMEEVENTS" + self.ext), header=True
            )
            procedureevents_mv = spark.read.csv(
                os.path.join(self.data_dir, "PROCEDUREEVENTS_MV" + self.ext), header=True
            )
            icustays = spark.read.csv(os.path.join(self.data_dir, "ICUSTAYS" + self.ext), header=True)

            # Project each table down to the columns declared in __init__.
            chartevents = chartevents.select(*dialysis_tables["chartevents"]["include"])
            inputevents_cv = inputevents_cv.select(*dialysis_tables["inputevents_cv"]["include"])
            outputevents = outputevents.select(*dialysis_tables["outputevents"]["include"])
            inputevents_mv = inputevents_mv.select(*dialysis_tables["inputevents_mv"]["include"])
            datetimeevents = datetimeevents.select(*dialysis_tables["datetimeevents"]["include"])
            procedureevents_mv = procedureevents_mv.select(*dialysis_tables["procedureevents_mv"]["include"])

            # Filter dialysis related tables with dialysis condition #TODO: check dialysis condition
            # CareVue chart events: per-itemid VALUE conditions identify
            # dialysis activity; rows flagged as ERROR are discarded.
            cv_ce = (
                chartevents.filter(F.col("ITEMID").isin(dialysis_tables["chartevents"]["itemid"]["cv_ce"]))
                .filter(F.col("VALUE").isNotNull())
                .filter((F.col("ERROR").isNull()) | (F.col("ERROR") == 0))
                .filter(
                    (
                        (F.col("ITEMID").isin([152, 148, 149, 146, 147, 151, 150]))
                        & (F.col("VALUE").isNotNull())
                    )
                    | (
                        (F.col("ITEMID").isin([229, 235, 241, 247, 253, 259, 265, 271]))
                        & (F.col("VALUE") == "Dialysis Line")
                    )
                    | ((F.col("ITEMID") == 466) & (F.col("VALUE") == "Dialysis RN"))
                    | ((F.col("ITEMID") == 927) & (F.col("VALUE") == "Dialysis Solutions"))
                    | ((F.col("ITEMID") == 6250) & (F.col("VALUE") == "dialys"))
                    | (
                        (F.col("ITEMID") == 917)
                        & (
                            F.col("VALUE").isin(
                                [
                                    "+ INITIATE DIALYSIS",
                                    "BLEEDING FROM DIALYSIS CATHETER",
                                    "FAILED DIALYSIS CATH.",
                                    "FEBRILE SYNDROME;DIALYSIS",
                                    "HYPOTENSION WITH HEMODIALYSIS",
                                    "HYPOTENSION.GLOGGED DIALYSIS",
                                    "INFECTED DIALYSIS CATHETER",
                                ]
                            )
                        )
                    )
                    | (
                        (F.col("ITEMID") == 582)
                        & (
                            F.col("VALUE").isin(
                                [
                                    "CAVH Start",
                                    "CAVH D/C",
                                    "CVVHD Start",
                                    "CVVHD D/C",
                                    "Hemodialysis st",
                                    "Hemodialysis end",
                                ]
                            )
                        )
                    )
                )
            )
            # Restrict CareVue chart events to stays sourced from CareVue.
            icustays = icustays.filter(F.col("DBSOURCE") == "carevue").select(self.icustay_key)
            cv_ce = cv_ce.join(icustays, on=self.icustay_key, how="inner")
            cv_ie = inputevents_cv.filter(
                F.col("ITEMID").isin(dialysis_tables["inputevents_cv"]["itemid"]["cv_ie"])
            ).filter(F.col("AMOUNT") > 0)
            cv_oe = outputevents.filter(
                F.col("ITEMID").isin(dialysis_tables["outputevents"]["itemid"]["cv_oe"])
            ).filter(F.col("VALUE") > 0)
            mv_ce = (
                chartevents.filter(F.col("ITEMID").isin(dialysis_tables["chartevents"]["itemid"]["mv_ce"]))
                .filter(F.col("VALUENUM") > 0)
                .filter((F.col("ERROR").isNull()) | (F.col("ERROR") == 0))
            )
            mv_ie = inputevents_mv.filter(
                F.col("ITEMID").isin(dialysis_tables["inputevents_mv"]["itemid"]["mv_ie"])
            ).filter(F.col("AMOUNT") > 0)
            mv_de = datetimeevents.filter(
                F.col("ITEMID").isin(dialysis_tables["datetimeevents"]["itemid"]["mv_de"])
            )
            mv_pe = procedureevents_mv.filter(
                F.col("ITEMID").isin(dialysis_tables["procedureevents_mv"]["itemid"]["mv_pe"])
            )

            def dialysis_time(table, timecolumn):
                # Reduce each source to (patient, dialysis event timestamp).
                return table.withColumn("_DIALYSIS_TIME", F.to_timestamp(timecolumn)).select(
                    self.patient_key, "_DIALYSIS_TIME"
                )

            cv_ce, cv_ie, cv_oe, mv_ce, mv_ie, mv_de, mv_pe = (
                dialysis_time(cv_ce, "CHARTTIME"),
                dialysis_time(cv_ie, "CHARTTIME"),
                dialysis_time(cv_oe, "CHARTTIME"),
                dialysis_time(mv_ce, "CHARTTIME"),
                dialysis_time(mv_ie, "STARTTIME"),
                dialysis_time(mv_de, "CHARTTIME"),
                dialysis_time(mv_pe, "STARTTIME"),
            )

            # Earliest dialysis event per patient across all sources; keep a
            # lab row only if it precedes the patient's first dialysis event.
            dialysis = cv_ce.union(cv_ie).union(cv_oe).union(mv_ce).union(mv_ie).union(mv_de).union(mv_pe)
            dialysis = dialysis.groupBy(self.patient_key).agg(F.min("_DIALYSIS_TIME").alias("_DIALYSIS_TIME"))
            merge = merge.join(dialysis, on=self.patient_key, how="left")
            merge = merge.filter(F.isnull("_DIALYSIS_TIME") | (F.col("_DIALYSIS_TIME") > F.col(timestamp)))
            merge = merge.drop("_DIALYSIS_TIME")

        if timeoffsetunit == "abs":
            # Absolute timestamps -> minutes since ICU INTIME.
            merge = merge.withColumn(
                timestamp, F.round((F.col(timestamp).cast("long") - F.col("INTIME").cast("long")) / 60)
            )

        # Cohort with events within (obs_size + gap_size) - (obs_size + pred_size)
        merge = merge.filter(((self.obs_size + self.gap_size) * 60) <= F.col(timestamp)).filter(
            ((self.obs_size + self.pred_size) * 60) >= F.col(timestamp)
        )

        # Average value of events
        value_agg = merge.groupBy(self.icustay_key).agg(
            F.mean(value).alias("avg_value")
        )  # TODO: mean/min/max?

        # Labeling: bin the per-stay average into the task's ordinal classes.
        if task == "bilirubin":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value < 1.2, 0)
                .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
                .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 6.0), 2)
                .when((value_agg.avg_value >= 6.0) & (value_agg.avg_value < 12.0), 3)
                .when(value_agg.avg_value >= 12.0, 4),
            )
        elif task == "platelets":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value >= 150, 0)
                .when((value_agg.avg_value >= 100) & (value_agg.avg_value < 150), 1)
                .when((value_agg.avg_value >= 50) & (value_agg.avg_value < 100), 2)
                .when((value_agg.avg_value >= 20) & (value_agg.avg_value < 50), 3)
                .when(value_agg.avg_value < 20, 4),
            )

        elif task == "creatinine":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value < 1.2, 0)
                .when((value_agg.avg_value >= 1.2) & (value_agg.avg_value < 2.0), 1)
                .when((value_agg.avg_value >= 2.0) & (value_agg.avg_value < 3.5), 2)
                .when((value_agg.avg_value >= 3.5) & (value_agg.avg_value < 5), 3)
                .when(value_agg.avg_value >= 5, 4),
            )

        elif task == "wbc":
            value_agg = value_agg.withColumn(
                task,
                F.when(value_agg.avg_value < 4, 0)
                .when((value_agg.avg_value >= 4) & (value_agg.avg_value <= 12), 1)
                .when((value_agg.avg_value > 12), 2),
            )

        # Left join keeps all stays; stays without events get a null label.
        cohorts = cohorts.join(value_agg.select(self.icustay_key, task), on=self.icustay_key, how="left")

        return cohorts
825
+
826
+ def infer_data_extension(self) -> str:
827
+ if len(glob.glob(os.path.join(self.data_dir, "*.csv.gz"))) == 26:
828
+ ext = ".csv.gz"
829
+ elif len(glob.glob(os.path.join(self.data_dir, "*.csv"))) == 26:
830
+ ext = ".csv"
831
+ else:
832
+ raise AssertionError(
833
+ "Provided data directory is not correct. Please check if --data is correct. "
834
+ "--data: {}".format(self.data_dir)
835
+ )
836
+
837
+ logger.info("Data extension is set to '{}'".format(ext))
838
+
839
+ return ext