genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
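
genhpf/scripts/preprocess/preprocess_meds.py (item 52; the only file whose contents are shown below, a new file with 674 added lines)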
@@ -0,0 +1,674 @@
+ import functools
+ import logging
+ import glob
+ import multiprocessing
+ import os
+ import re
+ import shutil
+ from argparse import ArgumentParser
+ import argparse
+ from bisect import bisect_left, bisect_right
+ from datetime import datetime
+ from pathlib import Path
+
+ import h5py
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+
+ pool_manager = multiprocessing.Manager()
+ warned_codes = pool_manager.list()
+
+
+ def find_boundary_between(tuples_list, start, end):
+     starts = [s for s, e in tuples_list]
+     ends = [e for s, e in tuples_list]
+
+     start_index = bisect_left(starts, start)
+     end_index = bisect_right(ends, end)
+     assert start_index < end_index
+
+     return start_index, end_index
+
+
+ def get_parser():
+     parser = ArgumentParser()
+     parser.add_argument(
+         "root",
+         help="path to MEDS dataset. it can be either of directory or the exact file path "
+         "with the file extension. if provided with directory, it tries to scan *.csv or "
+         "*.parquet files contained in the directory, including sub-directories, to process "
+         "all of them.",
+     )
+     parser.add_argument(
+         "--metadata_dir",
+         help="path to metadata directory for the input MEDS dataset, which contains codes.parquet",
+     )
+
+     parser.add_argument(
+         "--birth_code", type=str, default="MEDS_BIRTH", help="string code for the birth event in the dataset."
+     )
+
+     parser.add_argument(
+         "--cohort",
+         type=str,
+         help="path to the defined cohort, which must be a result of ACES. it can be either of "
+         "directory or the exact file path that has the same extension with the MEDS dataset "
+         "to be processed. the file structure of this cohort directory should be the same with "
+         "the provided MEDS dataset directory to match each cohort to its corresponding shard "
+         "data.",
+     )
+     parser.add_argument(
+         "--cohort_label_name",
+         type=str,
+         default="boolean_value",
+         help="column name in the cohort dataframe to be used for label",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="outputs",
+         help="directory to save processed outputs.",
+     )
+     parser.add_argument(
+         "--skip-if-exists",
+         action="store_true",
+         help="whether or not to skip the processing if the output directory already "
+         "exists.",
+     )
+     parser.add_argument(
+         "--rebase",
+         action="store_true",
+         help="whether or not to rebase the output directory if exists.",
+     )
+     parser.add_argument(
+         "--debug",
+         type=lambda x: {'true': True, 'false': False}[x.lower()],
+         default=False,
+         help="whether or not to enable the debug mode, which forces the script to be run with "
+         "only one worker."
+     )
+     parser.add_argument(
+         "--workers",
+         metavar="N",
+         default=1,
+         type=int,
+         help="number of parallel workers.",
+     )
+     parser.add_argument(
+         "--max_event_length",
+         metavar="N",
+         default=128,
+         type=int,
+         help="maximum number of tokens in an event.",
+     )
+
+     parser.add_argument(
+         "--mimic_dir",
+         default=None,
+         help="path to directory for MIMIC-IV database containing hosp/ and icu/ as a subdirectory. "
+         "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes.",
+     )
+
+     return parser
+
+
+ def main():
+     parser = get_parser()
+     args = parser.parse_args()
+     root_path = Path(args.root)
+     output_dir = Path(args.output_dir)
+     metadata_dir = Path(args.metadata_dir)
+     mimic_dir = Path(args.mimic_dir) if args.mimic_dir is not None else None
+
+     num_workers = max(args.workers, 1)
+     if args.debug:
+         print("debug mode is ON")
+         num_workers = 1
+         os.environ["RAYON_RS_NUM_CPUS"] = "1"
+     else:
+         cpu_count = multiprocessing.cpu_count()
+         if num_workers > cpu_count:
+             logger.warning(
+                 f"Number of workers (--workers) is greater than the number of available CPUs "
+                 f"({cpu_count}). Setting the number of workers to {cpu_count}."
+             )
+             num_workers = cpu_count
+
+     if root_path.is_dir():
+         data_paths = glob.glob(str(root_path / "**/*.csv"), recursive=True)
+         if len(data_paths) == 0:
+             data_paths = glob.glob(str(root_path / "**/*.parquet"), recursive=True)
+         if len(data_paths) == 0:
+             raise ValueError("Data directory does not contain any supported file formats: .csv or .parquet")
+     else:
+         data_paths = [root_path]
+
+     if not output_dir.exists():
+         output_dir.mkdir(parents=True)
+     else:
+         if args.rebase:
+             shutil.rmtree(output_dir)
+             output_dir.mkdir(parents=True)
+         elif output_dir.exists():
+             if args.skip_if_exists:
+                 ls = glob.glob(str(output_dir / "**/*"), recursive=True)
+                 expected_files = []
+                 for subset in set(os.path.dirname(x) for x in data_paths):
+                     expected_files.extend([
+                         os.path.join(str(output_dir), os.path.basename(subset), f"{i}.h5")
+                         for i in range(num_workers)
+                     ])
+                 if set(expected_files).issubset(set(ls)):
+                     logger.info(
+                         f"Output directory already contains the expected files. Skipping the "
+                         "processing as --skip-if-exists is set. If you want to rebase the directory, "
+                         "please run the script with --rebase."
+                     )
+                     return
+             else:
+                 raise ValueError(
+                     f"File exists: '{str(output_dir.resolve())}'. If you want to rebase the "
+                     "directory automatically, please run the script with --rebase."
+                 )
+
+     label_col_name = args.cohort_label_name
+
+     tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+     codes_metadata = pl.read_parquet(metadata_dir / "codes.parquet").to_pandas()
+     codes_metadata = codes_metadata.set_index("code")["description"].to_dict()
+     # do not allow to use static events or birth event
+     birth_code = args.birth_code
+     # if birth_code not in codes_metadata:
+     #     print(
+     #         f'"{birth_code}" is not found in the codes metadata, which may lead to '
+     #         "unexpected results since we currently exclude this event from the input data. "
+     #     )
+
+     if mimic_dir is not None:
+         d_items = pd.read_csv(mimic_dir / "icu" / "d_items.csv.gz")
+         d_items["itemid"] = d_items["itemid"].astype("str")
+         d_items = d_items.set_index("itemid")["label"].to_dict()
+         d_labitems = pd.read_csv(mimic_dir / "hosp" / "d_labitems.csv.gz")
+         d_labitems["itemid"] = d_labitems["itemid"].astype("str")
+         d_labitems = d_labitems.set_index("itemid")["label"].to_dict()
+     else:
+         d_items = None
+         d_labitems = None
+
+     max_event_length = args.max_event_length
+
+     progress_bar = tqdm(data_paths, total=len(data_paths))
+     for data_path in progress_bar:
+         progress_bar.set_description(str(data_path))
+
+         data_path = Path(data_path)
+         subdir = data_path.relative_to(root_path).parent
+         if data_path.suffix == ".csv":
+             data = pl.scan_csv(
+                 data_path,
+                 low_memory=True if args.debug else False,
+             )
+         elif data_path.suffix == ".parquet":
+             data = pl.scan_parquet(
+                 data_path,
+                 parallel="none" if args.debug else "auto",
+                 low_memory=True if args.debug else False,
+             )
+         else:
+             raise ValueError(f"Unsupported file format: {data_path.suffix}")
+
+         data = data.with_columns(
+             pl.when(pl.col("code") == birth_code).then(None).otherwise(pl.col("time")).alias("time")
+         )
+         data = data.drop_nulls(subset=["subject_id", "time"])
+
+         cohort_path = Path(args.cohort) / subdir / data_path.name
+
+         if cohort_path.suffix == ".csv":
+             cohort = pl.scan_csv(cohort_path)
+         elif cohort_path.suffix == ".parquet":
+             cohort = pl.scan_parquet(cohort_path)
+         else:
+             raise ValueError(f"Unsupported file format: {cohort_path.suffix}")
+
+         cohort = cohort.drop_nulls(label_col_name)
+         cohort = cohort.unique()
+
+         cohort = cohort.select(
+             [
+                 pl.col("subject_id"),
+                 pl.col(label_col_name),
+                 # pl.col("input.end_summary").struct.field("timestamp_at_start").alias("starttime"),
+                 pl.col("prediction_time").alias("endtime"),
+             ]
+         )
+         cohort = (
+             cohort.group_by("subject_id", maintain_order=True)
+             .agg(pl.col(["endtime", label_col_name]))
+             .collect()
+         )  # omitted "starttime"
+         cohort_dict = {
+             x["subject_id"]: {
+                 # "starttime": x["starttime"],
+                 "endtime": x["endtime"],
+                 "label": x[label_col_name],
+             }
+             for x in cohort.iter_rows(named=True)
+         }
+
+         def extract_cohort(row):
+             subject_id = row["subject_id"]
+             time = row["time"]
+             if subject_id not in cohort_dict:
+                 # return {"cohort_start": None, "cohort_end": None, "cohort_label": None}
+                 return {"cohort_end": None, "cohort_label": None}
+
+             cohort_criteria = cohort_dict[subject_id]
+             # starts = cohort_criteria["starttime"]
+             ends = cohort_criteria["endtime"]
+             labels = cohort_criteria["label"]
+
+             # for start, end, label in zip(starts, ends, labels):
+             #     if start <= time and time <= end:
+             #         return {"cohort_start": start, "cohort_end": end, "cohort_label": label}
+
+             # assume it is possible that each event goes into multiple different cohorts
+             cohort_ends = []
+             cohort_labels = []
+             for end, label in zip(ends, labels):
+                 if time <= end:
+                     # return {"cohort_start": start, "cohort_end": end, "cohort_label": label}
+                     cohort_ends.append(end)
+                     cohort_labels.append(label)
+
+             if len(cohort_ends) > 0:
+                 return {"cohort_end": cohort_ends, "cohort_label": cohort_labels}
+             else:
+                 # return {"cohort_start": None, "cohort_end": None, "cohort_label": None}
+                 return {"cohort_end": None, "cohort_label": None}
+
+         data = data.group_by(["subject_id", "time"], maintain_order=True).agg(pl.all())
+         data = (
+             data.with_columns(
+                 pl.struct(["subject_id", "time"])
+                 .map_elements(
+                     extract_cohort,
+                     return_dtype=pl.Struct(
+                         {
+                             "cohort_end": pl.List(pl.Datetime()),
+                             "cohort_label": pl.List(pl.Boolean),
+                         }
+                     ),
+                 )
+                 .alias("cohort_criteria")
+             )
+             .unnest("cohort_criteria")
+             .collect()
+         )
+
+         data = data.drop_nulls("cohort_label")
+
+         data = data.with_columns(pl.col("time").dt.strftime("%Y-%m-%d %H:%M:%S").cast(pl.List(str)))
+         data = data.with_columns(
+             pl.col("time").list.sample(n=pl.col("code").list.len(), with_replacement=True)
+         )
+         if args.debug:
+             data = data[:5000]
+
+         if str(subdir) != ".":
+             output_name = str(subdir)
+         else:
+             output_name = data_path.stem
+
+         if not os.path.exists(output_dir / output_name):
+             os.makedirs(output_dir / output_name)
+
+         with open(str(output_dir / (output_name + ".tsv")), "a") as manifest_f:
+             if os.path.getsize(output_dir / (output_name + ".tsv")) == 0:
+                 manifest_f.write(f"{output_dir}/{output_name}\n")
+
+             must_have_columns = [
+                 "subject_id",
+                 "cohort_end",
+                 "cohort_label",
+                 "time",
+                 "code",
+                 "numeric_value",
+             ]
+             rest_of_columns = [x for x in data.columns if x not in must_have_columns]
+             column_name_idcs = {col: i for i, col in enumerate(data.columns)}
+
+             meds_to_remed_partial = functools.partial(
+                 meds_to_remed,
+                 tokenizer,
+                 rest_of_columns,
+                 column_name_idcs,
+                 codes_metadata,
+                 output_dir,
+                 output_name,
+                 num_workers,
+                 d_items,
+                 d_labitems,
+                 warned_codes,
+                 max_event_length,
+                 args.debug,
+             )
+
+             # meds --> remed
+             logger.info(f"Start processing {data_path}")
+             if num_workers <= 1:
+                 length_per_subject_gathered = [meds_to_remed_partial(data)]
+                 del data
+             else:
+                 subject_ids = data["subject_id"].unique().to_list()
+                 n = num_workers
+                 subject_id_chunks = [subject_ids[i::n] for i in range(n)]
+                 data_chunks = []
+                 for subject_id_chunk in subject_id_chunks:
+                     data_chunks.append(data.filter(pl.col("subject_id").is_in(subject_id_chunk)))
+                 del data
+
+                 num_valid_data_chunks = sum(map(lambda x: len(x) > 0, data_chunks))
+                 if num_valid_data_chunks < num_workers:
+                     raise ValueError(
+                         "Number of valid data chunks (= number of unique subjects) were smaller "
+                         "than the specified num workers (--workers) due to the small size of data. "
+                         "Consider reducing the number of workers."
+                     )
+
+                 pool = multiprocessing.get_context("spawn").Pool(processes=num_workers)
+                 # the order is preserved
+                 length_per_subject_gathered = pool.map(meds_to_remed_partial, data_chunks)
+                 pool.close()
+                 pool.join()
+                 del data_chunks
+
+                 if len(length_per_subject_gathered) != num_workers:
+                     raise ValueError(
+                         "Number of processed workers were smaller than the specified num workers "
+                         "(--workers) due to the small size of data. Consider reducing the number of "
+                         "workers."
+                     )
+
+             for length_per_subject in length_per_subject_gathered:
+                 for subject_id, (length, shard_id) in length_per_subject.items():
+                     manifest_f.write(f"{subject_id}\t{length}\t{shard_id}\n")
+
+
+ def meds_to_remed(
+     tokenizer,
+     rest_of_columns,
+     column_name_idcs,
+     codes_metadata,
+     output_dir,
+     output_name,
+     num_shards,
+     d_items,
+     d_labitems,
+     warned_codes,
+     max_event_length,
+     debug,
+     df_chunk,
+ ):
+     code_matching_pattern = re.compile(r"\d+")
+
+     def meds_to_remed_unit(row):
+         events = []
+         digit_offsets = []
+         col_name_offsets = []
+         for event_index in range(len(row[column_name_idcs["code"]])):
+             event = ""
+             digit_offset = []
+             col_name_offset = []
+             for col_name in ["code", "numeric_value"] + rest_of_columns:
+                 # do not process something like "icustay_id" or "hadm_id"
+                 if "id" in col_name:
+                     continue
+
+                 col_event = row[column_name_idcs[col_name]][event_index]
+                 # print(f"col_name: {col_name}")
+                 # print(row[column_name_idcs[col_name]][event_index])
+                 if col_event is not None:
+                     col_event = str(col_event)
+                     # if col_name == "code":
+                     #     if col_event in codes_metadata and codes_metadata[col_event] != "":
+                     #         col_event = codes_metadata[col_event]
+                     #     else:
+                     #         do_break = False
+                     # safely resolve description: check for None or empty string
+                     if col_name == "code":
+                         desc = codes_metadata.get(col_event)
+                         if desc is not None and desc != "":
+                             col_event = desc
+                         else:
+                             do_break = False
+                             items = col_event.split("//")
+                             is_code = [bool(code_matching_pattern.fullmatch(item)) for item in items]
+                             if True in is_code:
+                                 if d_items is not None and d_labitems is not None:
+                                     code_idx = is_code.index(True)
+                                     code = items[code_idx]
+
+                                     if code in d_items:
+                                         desc = d_items[code]
+                                     elif code in d_labitems:
+                                         desc = d_labitems[code]
+                                     else:
+                                         do_break = True
+
+                                     if not do_break:
+                                         items[code_idx] = desc
+                                         col_event = "//".join(items)
+                                 else:
+                                     do_break = True
+                             if do_break and col_event not in warned_codes:
+                                 warned_codes.append(col_event)
+                                 logger.warning(
+                                     "The dataset contains some codes that are not specified in "
+                                     "the codes metadata, which may not be intended. Note that we "
+                                     f"process this code as it is for now: {col_event}."
+                                 )
+                     else:
+                         col_event = re.sub(
+                             r"\d*\.\d+",
+                             lambda x: str(round(float(x.group(0)), 4)),
+                             col_event,
+                         )
+                     event_offset = len(event) + len(col_name) + 1
+                     digit_offset_tmp = [
+                         g.span() for g in re.finditer(r"([0-9]+([.][0-9]*)?|[0-9]+|\.+)", col_event)
+                     ]
+
+                     internal_offset = 0
+                     for start, end in digit_offset_tmp:
+                         digit_offset.append(
+                             (
+                                 event_offset + start + internal_offset,
+                                 event_offset + end + (end - start) * 2 + internal_offset,
+                             )
+                         )
+                         internal_offset += (end - start) * 2
+
+                     col_event = re.sub(r"([0-9\.])", r" \1 ", col_event)
+                     if col_event is None:
+                         logger.warning(f"Skipped col_event for col_name {col_name} because it is None")
+                         continue
+                     col_name_offset.append((len(event), len(event) + len(col_name)))
+                     event += " " + col_name + " " + col_event
+             if len(event) > 0:
+                 events.append(event[1:])
+                 digit_offsets.append(digit_offset)
+                 col_name_offsets.append(col_name_offset)
+
+         tokenized_events = tokenizer(
+             events,
+             add_special_tokens=True,
+             padding="max_length",
+             max_length=max_event_length,
+             truncation=True,
+             return_tensors="np",
+             return_token_type_ids=False,
+             return_attention_mask=True,
+             return_offsets_mapping=True,
+         )
+         lengths_before_padding = tokenized_events["attention_mask"].sum(axis=1)
+
+         input_ids = tokenized_events["input_ids"]
+         dpe_ids = np.zeros(input_ids.shape, dtype=int)
+         for i, digit_offset in enumerate(digit_offsets):
+             for start, end in digit_offset:
+                 start_index, end_index = find_boundary_between(
+                     tokenized_events[i].offsets[: lengths_before_padding[i] - 1],
+                     start,
+                     end,
+                 )
+
+                 # define dpe ids for digits found
+                 num_digits = end_index - start_index
+                 # 119: token id for "."
+                 num_decimal_points = (input_ids[i][start_index:end_index] == 119).sum()
+
+                 # integer without decimal point
+                 # e.g., for "1 2 3 4 5", assign [10, 9, 8, 7, 6]
+                 if num_decimal_points == 0:
+                     dpe_ids[i][start_index:end_index] = list(range(num_digits + 5, 5, -1))
+                 # floats
+                 # e.g., for "1 2 3 4 5 . 6 7 8 9", assign [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+                 elif num_decimal_points == 1:
+                     num_decimals = (
+                         num_digits
+                         - np.where(input_ids[i][start_index:end_index] == 119)[0][0]  # 119: token id for "."
+                     )
+                     dpe_ids[i][start_index:end_index] = list(
+                         range(num_digits + 5 - num_decimals, 5 - num_decimals, -1)
+                     )
+                 # 1 > decimal points where we cannot define dpe ids
+                 else:
+                     continue
+         # define type ids
+         # for column names: 2
+         # for column values (contents): 3
+         # for CLS tokens: 5
+         # for SEP tokens: 6
+         type_ids = np.zeros(input_ids.shape, dtype=int)
+         type_ids[:, 0] = 5  # CLS tokens
+         for i, col_name_offset in enumerate(col_name_offsets):
+             type_ids[i][lengths_before_padding[i] - 1] = 6  # SEP tokens
+             # fill with type ids for column values
+             type_ids[i][1 : lengths_before_padding[i] - 1] = 3
+             for start, end in col_name_offset:
+                 start_index, end_index = find_boundary_between(
+                     tokenized_events[i].offsets[1 : lengths_before_padding[i] - 1],
+                     start,
+                     end,
+                 )
+                 # the first offset is always (0, 0) for CLS token, so we adjust it
+                 start_index += 1
+                 end_index += 1
+                 # finally replace with type ids for column names
+                 type_ids[i][start_index:end_index] = 2
+
+         return np.stack([input_ids, type_ids, dpe_ids], axis=1).astype(np.uint16)
+
+     events_data = []
+     worker_id = multiprocessing.current_process().name.split("-")[-1]
+     if worker_id == "MainProcess":
+         worker_id = 0
+     else:
+         # worker_id is incremental for every generated pool, so divide with num_shards
+         worker_id = (int(worker_id) - 1) % num_shards  # 1-based -> 0-based indexing
+     if worker_id == 0:
+         progress_bar = tqdm(df_chunk.iter_rows(), total=len(df_chunk))
+         progress_bar.set_description(f"Processing from worker-{worker_id}:")
+     else:
+         progress_bar = df_chunk.iter_rows()
+
+     for row in progress_bar:
+         events_data.append(meds_to_remed_unit(row))
+     data_length = list(map(len, events_data))
+     data_index_offset = np.zeros(len(data_length), dtype=np.int64)
+     data_index_offset[1:] = np.cumsum(data_length[:-1])
+     data_index = pl.Series(
+         "data_index",
+         map(
+             lambda x: [data_index_offset[x] + y for y in range(data_length[x])],
+             range(len(data_length)),
+         ),
+     )
+     events_data = np.concatenate(events_data)
+
+     df_chunk = df_chunk.select(["subject_id", "cohort_end", "cohort_label", "time"])
+     df_chunk = df_chunk.insert_column(4, data_index)
+     df_chunk = df_chunk.explode(["cohort_end", "cohort_label"])
+     df_chunk = df_chunk.group_by(
+         # ["subject_id", "cohort_start", "cohort_end", "cohort_label"], maintain_order=True
+         ["subject_id", "cohort_end", "cohort_label"],
+         maintain_order=True,
+     ).agg(pl.all())
+
+     if debug:
+         print("debug_mode is on!")
+         df_chunk = df_chunk.with_columns(
+             [
+                 pl.col("time").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(str))),
+                 pl.col("data_index").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(int)))
+             ]
+         )
+
+     df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
+     # regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
+     df_chunk = df_chunk.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
+     df_chunk = df_chunk.with_columns(
+         (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
+     )
+     # data = data.drop("suffix", "cohort_start", "cohort_end")
+     df_chunk = df_chunk.drop("suffix", "cohort_end")
+
+     length_per_subject = {}
+     progress_bar = tqdm(
+         df_chunk.iter_rows(),
+         total=len(df_chunk),
+         desc=f"Writing data from worker-{worker_id}:",
+     )
+
+     for sample in progress_bar:
+         with h5py.File(str(output_dir / output_name / f"{worker_id}.h5"), "a") as f:
+             if "ehr" in f:
+                 result = f["ehr"]
+             else:
+                 result = f.create_group("ehr")
+
+             sample_result = result.create_group(sample[0])
+
+             times = np.concatenate(sample[2])
+             data_indices = np.concatenate(sample[3])
+             if debug:
+                 data_indices = data_indices[-100:]
+                 times = times[-100:]
+
+             data = events_data[data_indices]
+             sample_result.create_dataset("hi", data=data, dtype="i2", compression="lzf", shuffle=True)
+
+             times = [datetime.strptime(x, "%Y-%m-%d %H:%M:%S") for x in times]
+             times = np.cumsum(np.diff(times))
+             times = list(map(lambda x: round(x.total_seconds() / 60), times))
+             times = np.array([0] + times)
+
+             sample_result.create_dataset("time", data=times, dtype="i")
+             sample_result.create_dataset("label", data=int(sample[1]))
+
+             length_per_subject[sample[0]] = (len(times), worker_id)
+     del df_chunk
+
+     return length_per_subject
+
+
+ if __name__ == "__main__":
+     main()
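
Usage note (editorial, not part of the package): based on the argument parser above, the script takes a MEDS data root as its positional argument plus --metadata_dir, --cohort, and --output_dir, and writes one HDF5 shard per worker ({worker_id}.h5 under <output_dir>/<output_name>/) together with a <output_name>.tsv manifest. The sketch below is a minimal, hedged example of reading that output back; the manifest path is a placeholder and the layout is inferred from meds_to_remed above.

# Minimal reading sketch; "outputs/train.tsv" is a hypothetical path, and the
# group/dataset names ("ehr", "hi", "time", "label") follow the writer code above.
import h5py

manifest_path = "outputs/train.tsv"
with open(manifest_path) as manifest:
    root = manifest.readline().strip()  # first line: "<output_dir>/<output_name>"
    rows = [line.rstrip("\n").split("\t") for line in manifest]  # subject_id, length, shard_id

cohort_id, length, shard_id = rows[0]
with h5py.File(f"{root}/{shard_id}.h5", "r") as f:
    sample = f["ehr"][cohort_id]
    hi = sample["hi"][:]      # (num_events, 3, max_event_length): input_ids, type_ids, dpe_ids
    time = sample["time"][:]  # minutes elapsed since the sample's first event
    label = sample["label"][()]
    assert len(hi) == len(time) == int(length)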