genhpf 1.0.11__py3-none-any.whl

Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/preprocess/genhpf/ehrs/ehr.py
@@ -0,0 +1,919 @@
import logging
import os
import pickle
import re
import shutil
import subprocess
import sys
from functools import reduce
from itertools import chain
from typing import List, Union

import h5py
import numpy as np
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType
from tqdm import tqdm
from transformers import AutoTokenizer
from utils import col_name_add, q_cut

logger = logging.getLogger(__name__)

class EHR(object):
    def __init__(self, cfg):
        self.cfg = cfg

        self.cache = cfg.cache

        cache_dir = os.path.expanduser("~/.cache/ehr")
        # cache_dir = self.cfg.dest
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        self.cache_dir = cache_dir

        if self.cache:
            logger.warning(
                "--cache is set to True. Note that this forces loading cached"
                f" data from {cache_dir},"
                " which may ignore some arguments such as --first_icu, as well as task-related"
                " arguments (--mortality, --los_3day, etc.)."
                " If you want to avoid this, do not set --cache to True."
            )

        self.data_dir = cfg.data
        self.ccs_path = cfg.ccs
        self.gem_path = cfg.gem
        self.ext = cfg.ext

        self.max_event_size = cfg.max_event_size if cfg.max_event_size is not None else sys.maxsize
        self.min_event_size = cfg.min_event_size if cfg.min_event_size is not None else 1
        assert self.min_event_size > 0, (
            "--min_event_size must be a positive integer",
            self.min_event_size,
        )
        assert self.min_event_size <= self.max_event_size, (
            self.min_event_size,
            self.max_event_size,
        )

        self.max_event_token_len = cfg.max_event_token_len
        self.max_patient_token_len = cfg.max_patient_token_len

        self.max_age = cfg.max_age if cfg.max_age is not None else sys.maxsize
        self.min_age = cfg.min_age if cfg.min_age is not None else 0
        assert self.min_age <= self.max_age, (self.min_age, self.max_age)

        self.obs_size = cfg.obs_size
        self.gap_size = cfg.gap_size
        self.pred_size = cfg.pred_size
        self.long_term_pred_size = cfg.long_term_pred_size

        self.first_icu = cfg.first_icu

        # Emb_type / feature_select
        self.emb_type = cfg.emb_type
        self.feature = cfg.feature
        self.bucket_num = cfg.bucket_num

        # tasks
        self.mortality = cfg.mortality
        self.long_term_mortality = cfg.long_term_mortality
        self.los_3day = cfg.los_3day
        self.los_7day = cfg.los_7day
        self.readmission = cfg.readmission
        self.final_acuity = cfg.final_acuity
        self.imminent_discharge = cfg.imminent_discharge
        self.diagnosis = cfg.diagnosis
        self.creatinine = cfg.creatinine
        self.bilirubin = cfg.bilirubin
        self.platelets = cfg.platelets
        self.wbc = cfg.wbc

        self.chunk_size = cfg.chunk_size

        self.dest = cfg.dest

        self.bins = cfg.bins

        os.makedirs(os.path.join(self.dest, "data"), exist_ok=True)
        os.makedirs(os.path.join(self.dest, "metadata"), exist_ok=True)

        self.special_tokens_dict = dict()
        self.max_special_tokens = 100

        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

        if self.emb_type == "textbase":
            self.cls_token_id = self.tokenizer.cls_token_id
            self.sep_token_id = self.tokenizer.sep_token_id

        elif self.emb_type == "codebase":
            self.cls_token_id = 1
            self.sep_token_id = 2

        self.table_type_id = 1
        self.column_type_id = 2
        self.value_type_id = 3
        self.timeint_type_id = 4
        self.cls_type_id = 5
        self.sep_type_id = 6

        self.others_dpe_id = 0

        self._icustay_fname = None
        self._patient_fname = None
        self._admission_fname = None
        self._diagnosis_fname = None

        self._icustay_key = None
        self._hadm_key = None

    @property
    def icustay_fname(self):
        return self._icustay_fname

    @property
    def patient_fname(self):
        return self._patient_fname

    @property
    def admission_fname(self):
        return self._admission_fname

    @property
    def diagnosis_fname(self):
        return self._diagnosis_fname

    @property
    def icustay_key(self):
        return self._icustay_key

    @property
    def hadm_key(self):
        return self._hadm_key

    @property
    def patient_key(self):
        return self._patient_key

    @property
    def determine_first_icu(self):
        return self._determine_first_icu

    @property
    def num_special_tokens(self):
        return len(self.special_tokens_dict)

    def build_cohorts(self, icustays, cached=False):
        if cached:
            cohorts = self.load_from_cache(self.ehr_name + ".cohorts")
            if cohorts is not None:
                return cohorts

        if not self.is_compatible(icustays):
            raise AssertionError(
                "{} does not have the required columns to build cohorts.".format(self.icustay_fname)
                + " Please make sure that the icustays dataframe is compatible with the other EHRs."
            )

        logger.info("Start building cohorts for {}".format(self.ehr_name))
        logger.info("Emb_type {}, Feature {}".format(self.emb_type, self.feature))
        obs_size = self.obs_size
        gap_size = self.gap_size

        icustays = icustays[icustays["LOS"] >= (obs_size + gap_size) / 24]
        icustays = icustays[(self.min_age <= icustays["AGE"]) & (icustays["AGE"] <= self.max_age)]

        # We define labels for the readmission task in this step
        # since it requires observing each subsequent ICU stay,
        # which may be excluded from the final cohorts.
        icustays.sort_values([self.hadm_key, self.icustay_key], inplace=True)

        if self.readmission:
            icustays["readmission"] = 1
            icustays.loc[
                icustays.groupby(self.hadm_key)[self.determine_first_icu].idxmax(), "readmission"
            ] = 0
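            # Illustration: within each hospital admission, every ICU stay starts with
            # readmission=1, and the stay with the largest value of `determine_first_icu`
            # (e.g. the latest INTIME, i.e. the last ICU stay of that admission) is reset
            # to 0, since no further ICU stay follows it.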
        if self.first_icu:
            icustays = icustays.loc[icustays.groupby(self.hadm_key)[self.determine_first_icu].idxmin()]

        logger.info("Cohorts have been built successfully. Loaded {} cohorts.".format(len(icustays)))

        self.save_to_cache(icustays, self.ehr_name + ".cohorts")

        return icustays

    # TODO process specific tasks according to user choice?
    def prepare_tasks(self, cohorts, spark, cached=False):
        if cached:
            labeled_cohorts = self.load_from_cache(self.ehr_name + ".cohorts.labeled")
            if labeled_cohorts is not None:
                return labeled_cohorts
            else:
                raise RuntimeError()

        logger.info("Start labeling cohorts for predictive tasks.")

        labeled_cohorts = cohorts[
            [
                self.hadm_key,
                self.icustay_key,
                self.patient_key,
                "readmission",
                "LOS",
                "INTIME",
                "OUTTIME",
                "DISCHTIME",
                "IN_ICU_MORTALITY",
                "HOS_DISCHARGE_LOCATION",
            ]
        ].copy()

        # Mortality prediction:
        # if the discharge location of an ICU stay is 'Death'
        # and intime + obs_size + gap_size < dischtime <= intime + obs_size + pred_size,
        # the stay is assigned a positive label for mortality prediction.

        if self.mortality:
            labeled_cohorts["mortality"] = (
                (
                    (labeled_cohorts["IN_ICU_MORTALITY"] == "Death")
                    | (labeled_cohorts["HOS_DISCHARGE_LOCATION"] == "Death")
                )
                & (self.obs_size * 60 + self.gap_size * 60 < labeled_cohorts["DISCHTIME"])
                & (labeled_cohorts["DISCHTIME"] <= self.obs_size * 60 + self.pred_size * 60)
            ).astype(int)
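            # Worked example (illustrative values, not package defaults): with obs_size=12,
            # gap_size=12 and pred_size=48 hours, DISCHTIME (minutes from INTIME) must satisfy
            # 12*60 + 12*60 = 1440 < DISCHTIME <= 12*60 + 48*60 = 3600 for a positive label,
            # i.e. death after the gap but within the prediction window.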

        if self.long_term_mortality:
            labeled_cohorts["long_term_mortality"] = (
                (
                    (labeled_cohorts["IN_ICU_MORTALITY"] == "Death")
                    | (labeled_cohorts["HOS_DISCHARGE_LOCATION"] == "Death")
                )
                & (self.obs_size * 60 + self.gap_size * 60 < labeled_cohorts["DISCHTIME"])
                & (labeled_cohorts["DISCHTIME"] <= self.obs_size * 60 + self.long_term_pred_size * 60)
            ).astype(int)

        if self.los_3day:
            labeled_cohorts["los_3day"] = (labeled_cohorts["LOS"] > 3).astype(int)
        if self.los_7day:
            labeled_cohorts["los_7day"] = (labeled_cohorts["LOS"] > 7).astype(int)

        if self.final_acuity or self.imminent_discharge:
            # If a discharge of 'Death' occurs in the ICU or the hospital,
            # we retain these cases for the imminent discharge task.
            labeled_cohorts["IN_HOSPITAL_MORTALITY"] = (
                (~labeled_cohorts["IN_ICU_MORTALITY"])
                & (labeled_cohorts["HOS_DISCHARGE_LOCATION"] == "Death")
            ).astype(int)

            if self.final_acuity:
                # define the final acuity prediction task
                labeled_cohorts["final_acuity"] = labeled_cohorts["HOS_DISCHARGE_LOCATION"]
                labeled_cohorts.loc[
                    labeled_cohorts["IN_ICU_MORTALITY"] == 1, "final_acuity"
                ] = "IN_ICU_MORTALITY"
                labeled_cohorts.loc[
                    labeled_cohorts["IN_HOSPITAL_MORTALITY"] == 1, "final_acuity"
                ] = "IN_HOSPITAL_MORTALITY"

                with open(os.path.join(self.dest, "metadata/final_acuity_classes.tsv"), "w") as f:
                    for i, cat in enumerate(
                        labeled_cohorts["final_acuity"].astype("category").cat.categories
                    ):
                        print("{}\t{}".format(i, cat), file=f)
                labeled_cohorts["final_acuity"] = labeled_cohorts["final_acuity"].astype("category").cat.codes
                # Replace -1 with NaN
                labeled_cohorts.loc[labeled_cohorts["final_acuity"] == -1, "final_acuity"] = np.nan
            if self.imminent_discharge:
                # define the imminent discharge prediction task
                is_discharged = (self.obs_size * 60 + self.gap_size * 60 <= labeled_cohorts["DISCHTIME"]) & (
                    labeled_cohorts["DISCHTIME"] <= self.obs_size * 60 + self.pred_size * 60
                )
                labeled_cohorts.loc[is_discharged, "imminent_discharge"] = labeled_cohorts.loc[
                    is_discharged, "HOS_DISCHARGE_LOCATION"
                ]
                labeled_cohorts.loc[
                    is_discharged
                    & (
                        (labeled_cohorts["IN_ICU_MORTALITY"] == 1)
                        | (labeled_cohorts["IN_HOSPITAL_MORTALITY"] == 1)
                    ),
                    "imminent_discharge",
                ] = "Death"
                labeled_cohorts.loc[~is_discharged, "imminent_discharge"] = "No Discharge"

                with open(os.path.join(self.dest, "metadata/imminent_discharge_classes.tsv"), "w") as f:
                    for i, cat in enumerate(
                        labeled_cohorts["imminent_discharge"].astype("category").cat.categories
                    ):
                        print("{}\t{}".format(i, cat), file=f)
                labeled_cohorts["imminent_discharge"] = (
                    labeled_cohorts["imminent_discharge"].astype("category").cat.codes
                )
                # Replace -1 with NaN
                labeled_cohorts.loc[
                    labeled_cohorts["imminent_discharge"] == -1, "imminent_discharge"
                ] = np.nan

            labeled_cohorts = labeled_cohorts.drop(columns=["IN_HOSPITAL_MORTALITY"])

        # clean up unnecessary columns
        labeled_cohorts = labeled_cohorts.drop(
            columns=["LOS", "IN_ICU_MORTALITY", "DISCHTIME", "HOS_DISCHARGE_LOCATION"]
        )

        self.save_to_cache(labeled_cohorts, self.ehr_name + ".cohorts.labeled")

        logger.info("Done preparing tasks except for diagnosis prediction.")

        return labeled_cohorts

    def process_tables(self, cohorts, spark):
        # in: cohorts, SparkSession
        # out: Spark DataFrame with (stay_id, time offset, inp, type, dpe)
        if isinstance(cohorts, pd.DataFrame):
            cohorts = cohorts[[self.hadm_key, self.icustay_key, "INTIME", "OUTTIME"]]
            logger.info("Start Preprocessing Tables, Cohort Numbers: {}".format(len(cohorts)))
            cohorts = spark.createDataFrame(cohorts)
            print("Converted Cohort to Pyspark DataFrame")
        else:
            logger.info("Start Preprocessing Tables")

        events_dfs = []
        for table_index, table in enumerate(self.tables):
            fname = table["fname"]
            table_name = fname.split("/")[-1][: -len(self.ext)]

            timestamp_key = table["timestamp"]
            excludes = table["exclude"]
            obs_size = self.obs_size
            logger.info("{} in progress.".format(fname))

            code_to_descriptions = None
            if "code" in table:
                code_to_descriptions = {
                    k: pd.read_csv(os.path.join(self.data_dir, v))
                    for k, v in zip(table["code"], table["desc"])
                }
                code_to_descriptions = {
                    k: dict(zip(v[k], v[d_k]))
                    for (k, v), d_k in zip(code_to_descriptions.items(), table["desc_key"])
                }

            infer_icustay_from_hadm_key = False

            events = spark.read.csv(os.path.join(self.data_dir, fname), header=True)
            if self.icustay_key not in events.columns:
                infer_icustay_from_hadm_key = True
                if self.hadm_key not in events.columns:
                    raise AssertionError(
                        "{} does not have any of these columns: {}".format(
                            fname, [self.icustay_key, self.hadm_key]
                        )
                    )

            events = events.drop(*excludes)
            if table["timeoffsetunit"] == "abs":
                events = events.withColumn(timestamp_key, F.to_timestamp(timestamp_key))

            if infer_icustay_from_hadm_key:
                events = events.join(
                    cohorts.select(self.hadm_key, self.icustay_key, "INTIME", "OUTTIME"),
                    on=self.hadm_key,
                    how="inner",
                )
                if table["timeoffsetunit"] == "abs":
                    events = (
                        events.withColumn(
                            "TEMP_TIME",
                            F.round((F.col(timestamp_key).cast("long") - F.col("INTIME").cast("long")) / 60),
                        )
                        .filter(F.col("TEMP_TIME") >= 0)
                        .filter(F.col("TEMP_TIME") <= F.col("OUTTIME"))
                        .drop("TEMP_TIME")
                    )
                else:
                    # All tables in eICU have icustay_key -> nothing to handle here
                    raise NotImplementedError()
                events = events.join(cohorts.select(self.icustay_key), on=self.icustay_key, how="leftsemi")

            else:
                events = events.join(
                    cohorts.select(self.icustay_key, "INTIME", "OUTTIME"), on=self.icustay_key, how="inner"
                )

            if table["timeoffsetunit"] == "abs":
                events = events.withColumn(
                    "TIME", F.round((F.col(timestamp_key).cast("long") - F.col("INTIME").cast("long")) / 60)
                )
                events = events.drop(timestamp_key)
            elif table["timeoffsetunit"] == "min":
                events = events.withColumn("TIME", F.col(timestamp_key).cast("int"))
                events = events.drop(timestamp_key)
            else:
                raise NotImplementedError()

            events = events.filter(F.col("TIME") >= 0).filter(F.col("TIME") <= obs_size * 60)
            events = events.drop("INTIME", "OUTTIME", self.hadm_key)

            if code_to_descriptions:
                for col in code_to_descriptions.keys():
                    mapping_expr = F.create_map([F.lit(x) for x in chain(*code_to_descriptions[col].items())])
                    events = events.withColumn(col, mapping_expr[F.col(col)])

            if self.emb_type == "codebase":
                print("codebase pre-process starts")
                events = events.toPandas()
                print("Converted events Pyspark DataFrame to Pandas DataFrame")

                code_col = table["code_feat"][0]
                dtype_code = events[code_col].dtype

                # Bucketize numeric features
                for numeric_col in table["numeric_feat"]:
                    if numeric_col in events.columns:
                        # split rows into numeric / non-numeric values
                        numeric = events[pd.to_numeric(events[numeric_col], errors="coerce").notnull()]
                        numeric = numeric.astype({numeric_col: "float"})
                        not_numeric = events[pd.to_numeric(events[numeric_col], errors="coerce").isnull()]

                        # bucketize
                        numeric.loc[:, numeric_col] = numeric.groupby(code_col)[numeric_col].transform(
                            lambda x: x.rank(method="dense")
                        )
                        numeric.loc[:, numeric_col] = numeric.groupby(code_col)[numeric_col].transform(
                            lambda x: q_cut(x, self.bucket_num)
                        )

                        numeric[numeric_col] = (
                            "B_" + numeric[code_col].astype("str") + "_" + numeric[numeric_col].astype("str")
                        )

                        events = pd.concat([numeric, not_numeric], axis=0)

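                # `q_cut` is imported from the package's preprocess utils and is not shown in this
                # file. Conceptually it assigns each (dense-ranked) value to one of `bucket_num`
                # quantile bins, so a numeric result becomes a token like "B_<code>_<bucket>".
                # A minimal stand-in with the same intent could look like the following
                # (an assumption, not the packaged implementation):
                #     def q_cut(x, bucket_num):
                #         return pd.qcut(x, q=bucket_num, labels=False, duplicates="drop")
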
                # Categorize categorical features
                for cate_col in table["categorical_feat"]:
                    if cate_col in events.columns:
                        events.loc[:, cate_col] = events[cate_col].map(lambda x: col_name_add(x, cate_col))

                events = events.astype({code_col: dtype_code})

                table_feature_unique = []
                for col in events.columns:
                    if col in [self.icustay_key, self.hadm_key, self.patient_key, "TIME"]:
                        continue
                    col_unique = list(events[col].unique())
                    table_feature_unique.extend(col_unique)

                table_feature_unique = list(set(table_feature_unique))

                if len(events_dfs) == 0:
                    max_idx = 3 + len(self.tables)
                    self.table_feature_dict = {k: idx + max_idx for idx, k in enumerate(table_feature_unique)}
                else:
                    max_idx = max(self.table_feature_dict.values())
                    table_feature_unique = [
                        k for k in table_feature_unique if k not in self.table_feature_dict.keys()
                    ]
                    self.table_feature_dict.update(
                        {k: idx + max_idx for idx, k in enumerate(table_feature_unique)}
                    )

                encoded_table_name = ([3 + table_index], [self.table_type_id], [0])
                encoded_cols = {
                    k: ([max(self.table_feature_dict.values()) + 1 + i], [self.column_type_id], [0])
                    for i, k in enumerate(events.columns)
                    if k not in [self.icustay_key, self.hadm_key, self.patient_key, "TIME"]
                }
                for i, k in enumerate(events.columns):
                    if k not in [self.icustay_key, self.hadm_key, self.patient_key, "TIME"]:
                        self.table_feature_dict[k] = max(self.table_feature_dict.values()) + 1 + i

                print("length of codebook = ", len(self.table_feature_dict))

                events = events.fillna(" ")
                events = spark.createDataFrame(events)
                print("Converted Events DataFrame to Pyspark DataFrame")

                def process_unit(feature, type_id):
                    input_ids = [self.table_feature_dict[feature]]
                    types = [type_id]
                    dpes = [0]
                    return input_ids, types, dpes

            elif self.emb_type == "textbase":
                print("textbase pre-process starts")

                def process_unit(text, type_id):
                    # Given (table_name|col|val), generate ([inp], [type], [dpe])
                    text = re.sub(r"\d*\.\d+", lambda x: str(round(float(x.group(0)), 4)), str(text))
                    number_groups = [g for g in re.finditer(r"([0-9]+([.][0-9]*)?|[0-9]+|\.+)", text)]
                    text = re.sub(r"([0-9\.])", r" \1 ", text)
                    input_ids = self.tokenizer.encode(text, add_special_tokens=False)
                    types = [type_id] * len(input_ids)

                    def get_dpe(tokens, number_groups):
                        number_ids = [121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 119]
                        numbers = [i for i, j in enumerate(tokens) if j in number_ids]
                        numbers_cnt = 0
                        data_dpe = [0] * len(tokens)
                        for group in number_groups:
                            if group[0] == "." * len(group[0]):
                                numbers_cnt += len(group[0])
                                continue

                            start = numbers[numbers_cnt]
                            end = numbers[numbers_cnt + len(group[0]) - 1] + 1
                            corresponding_numbers = tokens[start:end]
                            digits = [i for i, j in enumerate(corresponding_numbers) if j == 119]

                            # Case: integer
                            if len(digits) == 0:
                                data_dpe[start:end] = list(range(len(group[0]) + 5, 5, -1))
                            # Case: float
                            elif len(digits) == 1:
                                digit_idx = len(group[0]) - digits[0]
                                data_dpe[start:end] = list(
                                    range(len(group[0]) + 5 - digit_idx, 5 - digit_idx, -1)
                                )
                            else:
                                logger.warning(f"{data_dpe[start:end]} has irregular numerical formats")

                            numbers_cnt += len(group[0])
                        return data_dpe

                    dpes = get_dpe(input_ids, number_groups)
                    return input_ids, types, dpes
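                # Worked example of the digit place embedding (DPE): for the value "72.5", the
                # digits are spaced out so the tokenizer emits one token per character
                # ("7", "2", ".", "5"). The matched group "72.5" has length 4 with the dot at
                # relative index 2, so digit_idx = 2 and the DPE ids become [7, 6, 5, 4]:
                # the ones digit always receives 6, the decimal point 5 and the tenths digit 4,
                # so digits of the same magnitude share a DPE id regardless of the number's
                # length. A plain integer such as "86" gets [7, 6].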

                encoded_table_name = process_unit(table_name, self.table_type_id)  # table name
                encoded_cols = {k: process_unit(k, self.column_type_id) for k in events.columns}  # table columns

            schema = StructType(
                [
                    StructField("INPUTS", ArrayType(IntegerType()), False),
                    StructField("TYPES", ArrayType(IntegerType()), False),
                    StructField("DPES", ArrayType(IntegerType()), False),
                ]
            )

            # encoded_table_name -> tuple of (input, type, dpe)
            # encoded_cols -> dict of col_name -> (input, type, dpe)
            def process_row(encoded_table_name, encoded_cols):
                def _process_row(row):
                    """
                    input: row (cols: icustay_id, timestamp, ...)
                    output: (input, type, dpe)
                    """
                    row = row.asDict()
                    # Initialize with fresh lists here; reusing lists across UDF calls
                    # can corrupt results in PySpark.
                    input_ids, types, dpes = [], [], []
                    input_ids += encoded_table_name[0]
                    types += encoded_table_name[1]
                    dpes += encoded_table_name[2]
                    for col, val in row.items():
                        if col in [self.icustay_key, "TIME"] or val is None or val == " ":
                            continue
                        encoded_col = encoded_cols[col]
                        encoded_val = process_unit(val, self.value_type_id)

                        # append the encoded column and value only if the event still fits
                        if (
                            len(input_ids) + len(encoded_col[0]) + len(encoded_val[0]) + 2
                            <= self.max_event_token_len
                        ):
                            input_ids += encoded_col[0] + encoded_val[0]
                            types += encoded_col[1] + encoded_val[1]
                            dpes += encoded_col[2] + encoded_val[2]
                        else:
                            break
                    return input_ids, types, dpes

                return F.udf(_process_row, returnType=schema)

            events = (
                events.withColumn(
                    "tmp", process_row(encoded_table_name, encoded_cols)(F.struct(*events.columns))
                )
                .withColumn("INPUTS", F.col("tmp.INPUTS"))
                .withColumn("TYPES", F.col("tmp.TYPES"))
                .withColumn("DPES", F.col("tmp.DPES"))
                .select(self.icustay_key, "TIME", "INPUTS", "TYPES", "DPES")  # keep only these columns
            )  # one row per event with INPUTS / TYPES / DPES

            events_dfs.append(events)

        if self.emb_type == "codebase":
            os.makedirs(os.path.join(self.cache_dir, self.ehr_name), exist_ok=True)
            with open(
                os.path.join(self.cache_dir, self.ehr_name, f"codebase_code2idx_{self.feature}.pkl"), "wb"
            ) as f:
                pickle.dump(self.table_feature_dict, f)

        return reduce(lambda x, y: x.union(y), events_dfs)

    def make_input(self, cohorts, events, spark):
        @F.pandas_udf(returnType="TIME int", functionType=F.PandasUDFType.GROUPED_MAP)
        def _make_input(events):
            # This function does not actually need to return anything meaningful;
            # returning TIME only satisfies the pandas_udf interface.
            df = events.sort_values("TIME")
            flatten_cut_idx = -1
            # +1 accounts for the SEP token appended to each event
            flatten_lens = np.cumsum(df["INPUTS"].str.len() + 1).values
            event_length = len(df)

            if flatten_lens[-1] > self.max_patient_token_len - 1:
                # account for the CLS token at the start of the flattened input
                flatten_cut_idx = np.searchsorted(
                    flatten_lens, flatten_lens[-1] - self.max_patient_token_len + 1
                )
                flatten_lens = (flatten_lens - flatten_lens[flatten_cut_idx])[flatten_cut_idx + 1 :]
                event_length = len(flatten_lens)

            # Event length should not be longer than max_event_size
            event_length = min(event_length, self.max_event_size)
            df = df.iloc[-event_length:]

            if len(df) <= self.min_event_size:
                return events["TIME"].to_frame()

            def make_hi(cls_id, sep_id, iterable):
                return [[cls_id] + list(i) + [sep_id] for i in iterable]

            def make_fl(cls_id, sep_id, iterable):
                return [cls_id] + list(chain(*[list(i) + [sep_id] for i in iterable]))
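            # Illustrative layout for two events A and B (CLS/SEP shown symbolically):
            #   hierarchical: hi_input = [[CLS, A1, A2, ..., SEP], [CLS, B1, B2, ..., SEP]]
            #   flattened:    fl_input = [CLS, A1, A2, ..., SEP, B1, B2, ..., SEP]
            # hi_type / hi_dpe and fl_type / fl_dpe follow the same layout, with the CLS/SEP
            # positions filled by cls_type_id / sep_type_id and others_dpe_id respectively.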

            hi_input = make_hi(self.cls_token_id, self.sep_token_id, df["INPUTS"])
            hi_type = make_hi(self.cls_type_id, self.sep_type_id, df["TYPES"])
            hi_dpe = make_hi(self.others_dpe_id, self.others_dpe_id, df["DPES"])

            fl_input = make_fl(self.cls_token_id, self.sep_token_id, df["INPUTS"])
            fl_type = make_fl(self.cls_type_id, self.sep_type_id, df["TYPES"])
            fl_dpe = make_fl(self.others_dpe_id, self.others_dpe_id, df["DPES"])

            assert len(hi_input) <= self.max_event_size, hi_input
            assert all([len(i) <= self.max_event_token_len for i in hi_input]), hi_input
            assert len(fl_input) <= self.max_patient_token_len, fl_input

            # Pad so the sequences can be stored as fixed-shape numpy arrays
            hi_input = np.array(
                [np.pad(i, (0, self.max_event_token_len - len(i)), mode="constant") for i in hi_input]
            )
            hi_type = np.array(
                [np.pad(i, (0, self.max_event_token_len - len(i)), mode="constant") for i in hi_type]
            )
            hi_dpe = np.array(
                [np.pad(i, (0, self.max_event_token_len - len(i)), mode="constant") for i in hi_dpe]
            )

            fl_input = np.pad(fl_input, (0, self.max_patient_token_len - len(fl_input)), mode="constant")
            fl_type = np.pad(fl_type, (0, self.max_patient_token_len - len(fl_type)), mode="constant")
            fl_dpe = np.pad(fl_dpe, (0, self.max_patient_token_len - len(fl_dpe)), mode="constant")

            stay_id = df[self.icustay_key].values[0]
            # Cache per-stay arrays as pickles (cannot write to hdf5 directly from pyspark workers)
            data = {
                "hi": np.stack([hi_input, hi_type, hi_dpe], axis=1).astype(np.int16),
                "fl": np.stack([fl_input, fl_type, fl_dpe], axis=0).astype(np.int16),
                "time": df["TIME"].values,
            }
            with open(
                os.path.join(self.cache_dir, self.ehr_name, self.emb_type, self.feature, f"{stay_id}.pkl"),
                "wb",
            ) as f:
                pickle.dump(data, f)
            return events["TIME"].to_frame()

        os.makedirs(os.path.join(self.cache_dir, self.ehr_name, self.emb_type, self.feature), exist_ok=True)

        events.groupBy(self.icustay_key).apply(_make_input).write.mode("overwrite").format("noop").save()

        logger.info("Finished data preprocessing. Start writing to hdf5.")

        f = h5py.File(os.path.join(self.dest, "data/data.h5"), "w")
        ehr_g = f.create_group("ehr")

        active_stay_ids = []

        for stay_id_file in tqdm(
            os.listdir(os.path.join(self.cache_dir, self.ehr_name, self.emb_type, self.feature))
        ):
            stay_id = stay_id_file.split(".")[0]
            # use a separate handle so the hdf5 file object `f` is not shadowed
            with open(
                os.path.join(self.cache_dir, self.ehr_name, self.emb_type, self.feature, stay_id_file), "rb"
            ) as pkl_f:
                data = pickle.load(pkl_f)
            stay_g = ehr_g.create_group(str(stay_id))
            stay_g.create_dataset("hi", data=data["hi"], dtype="i2", compression="lzf", shuffle=True)
            stay_g.create_dataset("fl", data=data["fl"], dtype="i2", compression="lzf", shuffle=True)
            stay_g.create_dataset("time", data=data["time"], dtype="i")
            active_stay_ids.append(int(stay_id))

        shutil.rmtree(
            os.path.join(self.cache_dir, self.ehr_name, self.emb_type, self.feature), ignore_errors=True
        )
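        # Resulting layout of data.h5 (illustrative):
        #   ehr/<stay_id>/hi   -> (num_events, 3, max_event_token_len) int16: input, type, dpe per event
        #   ehr/<stay_id>/fl   -> (3, max_patient_token_len) int16: flattened input, type, dpe
        #   ehr/<stay_id>/time -> per-event minute offsets from INTIME
        # Task labels are attached below as hdf5 attributes on each stay group.
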
        # Drop patients with too few events

        if not isinstance(cohorts, pd.DataFrame):
            cohorts = cohorts.toPandas()
            print(cohorts)

        logger.info(
            "Total {} patients in the cohort are skipped due to too few events".format(
                len(cohorts) - len(active_stay_ids)
            )
        )
        cohorts = cohorts[cohorts[self.icustay_key].isin(active_stay_ids)]

        # Should consider pat_id for split
        # for seed in self.seed:
        #     shuffled = cohorts.groupby(
        #         self.patient_key
        #     )[self.patient_key].count().sample(frac=1, random_state=seed)
        #     cum_len = shuffled.cumsum()

        #     cohorts.loc[cohorts[self.patient_key].isin(
        #         shuffled[cum_len < int(sum(shuffled)*self.valid_percent)].index), f'split_{seed}'] = 'test'
        #     cohorts.loc[cohorts[self.patient_key].isin(
        #         shuffled[(cum_len >= int(sum(shuffled)*self.valid_percent))
        #         & (cum_len < int(sum(shuffled)*2*self.valid_percent))].index), f'split_{seed}'] = 'valid'
        #     cohorts.loc[
        #         cohorts[self.patient_key].isin(
        #             shuffled[cum_len >= int(sum(shuffled)*2*self.valid_percent)].index
        #         ), f'split_{seed}'
        #     ] = 'train'

        # select the stay_id column and the task columns
        task_columns = [
            "mortality",
            "long_term_mortality",
            "los_3day",
            "los_7day",
            "readmission",
            "final_acuity",
            "imminent_discharge",
        ]
        selected_columns = [self.icustay_key] + [col for col in task_columns if col in cohorts.columns]
        cohorts = cohorts[selected_columns]
        cohorts.rename(columns={self.icustay_key: "stay_id"}, inplace=True)

        cohorts.to_csv(os.path.join(self.dest, "data/label.csv"), index=False)

        # Record the cohorts dataframe as attributes in the hdf5 file
        for _, row in cohorts.iterrows():
            group = ehr_g[str(row["stay_id"])]
            for col in cohorts.columns:
                if isinstance(row[col], (pd.Timestamp, pd.Timedelta)):
                    continue
                group.attrs[col] = row[col]
        f.close()
        logger.info("Done encoding events.")

        return

    def run_pipeline(self, spark) -> None:
        cohorts = self.build_cohorts(cached=self.cache)
        labeled_cohorts = self.prepare_tasks(cohorts, spark, cached=self.cache)
        events = self.process_tables(labeled_cohorts, spark)
        self.make_input(labeled_cohorts, events, spark)

    def add_special_tokens(self, new_special_tokens: Union[str, List]) -> None:
        if isinstance(new_special_tokens, str):
            new_special_tokens = [new_special_tokens]

        num_special_tokens = self.num_special_tokens
        overlapped = []
        for new_special_token in new_special_tokens:
            if new_special_token in self.special_tokens_dict:
                overlapped.append(new_special_token)

        if len(overlapped) > 0:
            logger.warning(
                "There are some tokens that have already been set to special tokens."
                " Please provide only NEW tokens. Aborted."
            )
            return None
        elif num_special_tokens + len(new_special_tokens) > self.max_special_tokens:
            logger.warning(
                f"Total additional special tokens should be less than {self.max_special_tokens}."
                " Aborted."
            )
            return None

        self.special_tokens_dict.update(
            {
                k: "[unused{}]".format(i)
                for i, k in enumerate(new_special_tokens, start=num_special_tokens + 1)
            }
        )
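        # For instance, add_special_tokens(["[TIME_INT]"]) (a hypothetical token name) would map
        # that token to the next free "[unusedN]" slot of the BERT vocabulary, up to
        # max_special_tokens = 100 of them, so it can pass through the pretrained tokenizer
        # without growing the embedding matrix.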

    def make_compatible(self, icustays):
        """
        Make the different EHRs compatible with one another here.
        NOTE: OUTTIME/DISCHTIME are converted to relative minutes from INTIME,
        but INTIME itself is kept as the original value for later use.
        """
        raise NotImplementedError()

    def is_compatible(self, icustays):
        checklist = [
            self.hadm_key,
            self.icustay_key,
            self.patient_key,
            "LOS",
            "AGE",
            "INTIME",
            "OUTTIME",
            "DISCHTIME",
            "IN_ICU_MORTALITY",
            "HOS_DISCHARGE_LOCATION",
        ]
        for item in checklist:
            if item not in icustays.columns.to_list():
                return False
        return True

    def save_to_cache(self, f, fname, use_pickle=False) -> None:
        if use_pickle:
            import pickle

            with open(os.path.join(self.cache_dir, fname), "wb") as fptr:
                pickle.dump(f, fptr)
        else:
            f.to_pickle(os.path.join(self.cache_dir, fname))

    def load_from_cache(self, fname):
        cached = os.path.join(self.cache_dir, fname)
        if os.path.exists(cached):
            data = pd.read_pickle(cached)

            logger.info("Loaded data from {}".format(cached))
            return data
        else:
            return None

    def infer_data_extension(self) -> str:
        raise NotImplementedError()

    def download_ehr_from_url(self, url, dest) -> None:
        username = input("Email or Username: ")
        subprocess.run(
            [
                "wget",
                "-r",
                "-N",
                "-c",
                "-np",
                "--user",
                username,
                "--ask-password",
                url,
                "-P",
                dest,
            ]
        )
        output_dir = url.replace("https://", "").replace("http://", "")

        if not os.path.exists(os.path.join(dest, output_dir)):
            raise AssertionError(
                "Download failed. Please check your network connection or"
                " whether you are logged in with a credentialed user."
            )

    def download_ccs_from_url(self, dest) -> None:
        subprocess.run(
            [
                "wget",
                "-N",
                "-c",
                "https://www.hcup-us.ahrq.gov/toolssoftware/ccs/Multi_Level_CCS_2015.zip",
                "-P",
                dest,
            ]
        )

        import zipfile

        with zipfile.ZipFile(os.path.join(dest, "Multi_Level_CCS_2015.zip"), "r") as zip_ref:
            zip_ref.extractall(os.path.join(dest, "foo.d"))
        os.rename(
            os.path.join(dest, "foo.d", "ccs_multi_dx_tool_2015.csv"),
            os.path.join(dest, "ccs_multi_dx_tool_2015.csv"),
        )
        os.remove(os.path.join(dest, "Multi_Level_CCS_2015.zip"))
        shutil.rmtree(os.path.join(dest, "foo.d"))

    def download_icdgem_from_url(self, dest) -> None:
        subprocess.run(
            [
                "wget",
                "-N",
                "-c",
                "https://data.nber.org/gem/icd10cmtoicd9gem.csv",
                "-P",
                dest,
            ]
        )
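
The EHR class above is abstract: dataset-specific subclasses (see ehrs/mimiciii.py, ehrs/mimiciv.py, and ehrs/eicu.py in the file list) fill in the table definitions and keys and are driven through run_pipeline. A minimal driver sketch follows; the subclass name, import path, and config fields are assumptions for illustration, not the package's confirmed API.

# Illustrative driver only; MIMICIV, its import path, and the cfg fields are assumed here.
import os

import h5py
from pyspark.sql import SparkSession

from ehrs import MIMICIV  # hypothetical: a concrete subclass of EHR


def preprocess(cfg):
    spark = SparkSession.builder.appName("genhpf-preprocess").getOrCreate()

    ehr = MIMICIV(cfg)       # defines tables, icustay/hadm/patient keys, ehr_name, etc.
    ehr.run_pipeline(spark)  # build_cohorts -> prepare_tasks -> process_tables -> make_input

    # Outputs: <dest>/data/data.h5 with one "ehr/<stay_id>" group per ICU stay
    # (datasets "hi", "fl", "time"; labels stored as group attributes) and <dest>/data/label.csv.
    with h5py.File(os.path.join(cfg.dest, "data", "data.h5"), "r") as f:
        stay_id, group = next(iter(f["ehr"].items()))
        print(stay_id, group["hi"].shape, group["fl"].shape, dict(group.attrs))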