genhpf-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of genhpf might be problematic.

Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +233 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +174 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +584 -0
  53. genhpf/scripts/test.py +261 -0
  54. genhpf/scripts/train.py +350 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.0.dist-info/LICENSE +21 -0
  63. genhpf-1.0.0.dist-info/METADATA +197 -0
  64. genhpf-1.0.0.dist-info/RECORD +67 -0
  65. genhpf-1.0.0.dist-info/WHEEL +5 -0
  66. genhpf-1.0.0.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.0.dist-info/top_level.txt +1 -0
genhpf/scripts/preprocess/preprocess_meds.py
@@ -0,0 +1,584 @@
+ import functools
+ import glob
+ import multiprocessing
+ import os
+ import re
+ import shutil
+ import warnings
+ from argparse import ArgumentParser
+ from bisect import bisect_left, bisect_right
+ from datetime import datetime
+ from pathlib import Path
+
+ import h5py
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ pool_manager = multiprocessing.Manager()
+ warned_codes = pool_manager.list()
+
+
+ def find_boundary_between(tuples_list, start, end):
+     starts = [s for s, e in tuples_list]
+     ends = [e for s, e in tuples_list]
+
+     start_index = bisect_left(starts, start)
+     end_index = bisect_right(ends, end)
+     assert start_index < end_index
+
+     return start_index, end_index
+
+
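The helper above maps a character span back to token indices through the tokenizer's offset mapping. A minimal sketch of that lookup, using a hypothetical offset list that is not taken from the package:

    from bisect import bisect_left, bisect_right

    # hypothetical (start, end) character offsets of four tokens
    offsets = [(0, 4), (5, 7), (8, 9), (10, 12)]
    span_start, span_end = 5, 9  # character span covering the 2nd and 3rd tokens

    starts = [s for s, e in offsets]
    ends = [e for s, e in offsets]
    start_index = bisect_left(starts, span_start)  # -> 1
    end_index = bisect_right(ends, span_end)       # -> 3 (exclusive)
    print(offsets[start_index:end_index])          # [(5, 7), (8, 9)]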
+ def get_parser():
+     parser = ArgumentParser()
+     parser.add_argument(
+         "root",
+         help="path to MEDS dataset. it can be either of directory or the exact file path "
+         "with the file extension. if provided with directory, it tries to scan *.csv or "
+         "*.parquet files contained in the directory, including sub-directories, to process "
+         "all of them.",
+     )
+     parser.add_argument(
+         "--metadata_dir",
+         help="path to metadata directory for the input MEDS dataset, which contains codes.parquet",
+     )
+
+     parser.add_argument(
+         "--birth_code", type=str, default="MEDS_BIRTH", help="string code for the birth event in the dataset."
+     )
+
+     parser.add_argument(
+         "--cohort",
+         type=str,
+         help="path to the defined cohort, which must be a result of ACES. it can be either of "
+         "directory or the exact file path that has the same extension with the MEDS dataset "
+         "to be processed. the file structure of this cohort directory should be the same with "
+         "the provided MEDS dataset directory to match each cohort to its corresponding shard "
+         "data.",
+     )
+     parser.add_argument(
+         "--cohort_label_name",
+         type=str,
+         default="boolean_value",
+         help="column name in the cohort dataframe to be used for label",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="outputs",
+         help="directory to save processed outputs.",
+     )
+     parser.add_argument(
+         "--rebase",
+         action="store_true",
+         help="whether or not to rebase the output directory if exists.",
+     )
+     parser.add_argument(
+         "--workers",
+         metavar="N",
+         default=1,
+         type=int,
+         help="number of parallel workers.",
+     )
+     parser.add_argument(
+         "--max_event_length",
+         metavar="N",
+         default=128,
+         type=int,
+         help="maximum number of tokens in an event.",
+     )
+
+     parser.add_argument(
+         "--mimic_dir",
+         default=None,
+         help="path to directory for MIMIC-IV database containing hosp/ and icu/ as a subdirectory. "
+         "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes.",
+     )
+
+     return parser
+
+
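For reference, a hedged usage sketch of this parser; the paths below are placeholders, and only the option names and defaults come from get_parser() above. It assumes the module is importable as genhpf.scripts.preprocess.preprocess_meds (the path shown in the file list):

    from genhpf.scripts.preprocess.preprocess_meds import get_parser, main

    # hypothetical paths; replace with a real MEDS dataset and ACES cohort
    args = get_parser().parse_args(
        [
            "/path/to/meds",                    # root: MEDS directory or file
            "--metadata_dir", "/path/to/meds/metadata",
            "--cohort", "/path/to/aces_cohort",
            "--output_dir", "outputs",
            "--workers", "4",
        ]
    )
    main(args)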
+ def main(args):
+     root_path = Path(args.root)
+     output_dir = Path(args.output_dir)
+     metadata_dir = Path(args.metadata_dir)
+     mimic_dir = Path(args.mimic_dir) if args.mimic_dir is not None else None
+
+     if not output_dir.exists():
+         output_dir.mkdir()
+     else:
+         if args.rebase:
+             shutil.rmtree(output_dir)
+         if output_dir.exists():
+             raise ValueError(
+                 f"File exists: '{str(output_dir.resolve())}'. If you want to rebase the "
+                 "directory, please run the script with --rebase."
+             )
+         output_dir.mkdir()
+
+     if root_path.is_dir():
+         data_paths = glob.glob(str(root_path / "**/*.csv"), recursive=True)
+         if len(data_paths) == 0:
+             data_paths = glob.glob(str(root_path / "**/*.parquet"), recursive=True)
+         if len(data_paths) == 0:
+             raise ValueError("Data directory does not contain any supported file formats: .csv or .parquet")
+     else:
+         data_paths = [root_path]
+
+     label_col_name = args.cohort_label_name
+
+     tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+     codes_metadata = pl.read_parquet(metadata_dir / "codes.parquet").to_pandas()
+     codes_metadata = codes_metadata.set_index("code")["description"].to_dict()
+     # do not allow to use static events or birth event
+     birth_code = args.birth_code
+     # if birth_code not in codes_metadata:
+     # print(
+     # f'"{birth_code}" is not found in the codes metadata, which may lead to '
+     # "unexpected results since we currently exclude this event from the input data. "
+     # )
+
+     if mimic_dir is not None:
+         d_items = pd.read_csv(mimic_dir / "icu" / "d_items.csv.gz")
+         d_items["itemid"] = d_items["itemid"].astype("str")
+         d_items = d_items.set_index("itemid")["label"].to_dict()
+         d_labitems = pd.read_csv(mimic_dir / "hosp" / "d_labitems.csv.gz")
+         d_labitems["itemid"] = d_labitems["itemid"].astype("str")
+         d_labitems = d_labitems.set_index("itemid")["label"].to_dict()
+     else:
+         d_items = None
+         d_labitems = None
+
+     max_event_length = args.max_event_length
+
+     progress_bar = tqdm(data_paths, total=len(data_paths))
+     for data_path in progress_bar:
+         progress_bar.set_description(str(data_path))
+
+         data_path = Path(data_path)
+         subdir = data_path.relative_to(root_path).parent
+         if data_path.suffix == ".csv":
+             data = pl.scan_csv(data_path)
+         elif data_path.suffix == ".parquet":
+             data = pl.scan_parquet(data_path)
+         else:
+             raise ValueError(f"Unsupported file format: {data_path.suffix}")
+
+         data = data.with_columns(
+             pl.when(pl.col("code") == birth_code).then(None).otherwise(pl.col("time")).alias("time")
+         )
+         data = data.drop_nulls(subset=["subject_id", "time"])
+
+         cohort_path = Path(args.cohort) / subdir / data_path.name
+
+         if cohort_path.suffix == ".csv":
+             cohort = pl.scan_csv(cohort_path)
+         elif cohort_path.suffix == ".parquet":
+             cohort = pl.scan_parquet(cohort_path)
+         else:
+             raise ValueError(f"Unsupported file format: {cohort_path.suffix}")
+
+         cohort = cohort.drop_nulls(label_col_name)
+         cohort = cohort.unique()
+
+         cohort = cohort.select(
+             [
+                 pl.col("subject_id"),
+                 pl.col(label_col_name),
+                 # pl.col("input.end_summary").struct.field("timestamp_at_start").alias("starttime"),
+                 pl.col("prediction_time").alias("endtime"),
+             ]
+         )
+         cohort = (
+             cohort.group_by("subject_id", maintain_order=True)
+             .agg(pl.col(["endtime", label_col_name]))
+             .collect()
+         ) # omitted "starttime"
+         cohort_dict = {
+             x["subject_id"]: {
+                 # "starttime": x["starttime"],
+                 "endtime": x["endtime"],
+                 "label": x[label_col_name],
+             }
+             for x in cohort.iter_rows(named=True)
+         }
+
+         def extract_cohort(row):
+             subject_id = row["subject_id"]
+             time = row["time"]
+             if subject_id not in cohort_dict:
+                 # return {"cohort_start": None, "cohort_end": None, "cohort_label": None}
+                 return {"cohort_end": None, "cohort_label": None}
+
+             cohort_criteria = cohort_dict[subject_id]
+             # starts = cohort_criteria["starttime"]
+             ends = cohort_criteria["endtime"]
+             labels = cohort_criteria["label"]
+
+             # for start, end, label in zip(starts, ends, labels):
+             # if start <= time and time <= end:
+             # return {"cohort_start": start, "cohort_end": end, "cohort_label": label}
+
+             # assume it is possible that each event goes into multiple different cohorts
+             cohort_ends = []
+             cohort_labels = []
+             for end, label in zip(ends, labels):
+                 if time <= end:
+                     # return {"cohort_start": start, "cohort_end": end, "cohort_label": label}
+                     cohort_ends.append(end)
+                     cohort_labels.append(label)
+
+             if len(cohort_ends) > 0:
+                 return {"cohort_end": cohort_ends, "cohort_label": cohort_labels}
+             else:
+                 # return {"cohort_start": None, "cohort_end": None, "cohort_label": None}
+                 return {"cohort_end": None, "cohort_label": None}
+
+         data = data.group_by(["subject_id", "time"], maintain_order=True).agg(pl.all())
+         data = (
+             data.with_columns(
+                 pl.struct(["subject_id", "time"])
+                 .map_elements(
+                     extract_cohort,
+                     return_dtype=pl.Struct(
+                         {
+                             "cohort_end": pl.List(pl.Datetime()),
+                             "cohort_label": pl.List(pl.Boolean),
+                         }
+                     ),
+                 )
+                 .alias("cohort_criteria")
+             )
+             .unnest("cohort_criteria")
+             .collect()
+         )
+
+         data = data.drop_nulls("cohort_label")
+
+         data = data.with_columns(pl.col("time").dt.strftime("%Y-%m-%d %H:%M:%S").cast(pl.List(str)))
+         data = data.with_columns(
+             pl.col("time").list.sample(n=pl.col("code").list.len(), with_replacement=True)
+         )
+
+         if str(subdir) != ".":
+             output_name = str(subdir)
+         else:
+             output_name = data_path.stem
+
+         if not os.path.exists(output_dir / output_name):
+             os.makedirs(output_dir / output_name)
+
+         with open(str(output_dir / (output_name + ".tsv")), "a") as manifest_f:
+             if os.path.getsize(output_dir / (output_name + ".tsv")) == 0:
+                 manifest_f.write(f"{output_dir}/{output_name}\n")
+
+             must_have_columns = [
+                 "subject_id",
+                 "cohort_end",
+                 "cohort_label",
+                 "time",
+                 "code",
+                 "numeric_value",
+             ]
+             rest_of_columns = [x for x in data.columns if x not in must_have_columns]
+             column_name_idcs = {col: i for i, col in enumerate(data.columns)}
+
+             meds_to_remed_partial = functools.partial(
+                 meds_to_remed,
+                 tokenizer,
+                 rest_of_columns,
+                 column_name_idcs,
+                 codes_metadata,
+                 output_dir,
+                 output_name,
+                 args.workers,
+                 d_items,
+                 d_labitems,
+                 warned_codes,
+                 max_event_length,
+             )
+
+             # meds --> remed
+             print("Processing...")
+             if args.workers <= 1:
+                 length_per_subject_gathered = [meds_to_remed_partial(data)]
+                 del data
+             else:
+                 subject_ids = data["subject_id"].unique().to_list()
+                 n = args.workers
+                 subject_id_chunks = [subject_ids[i::n] for i in range(n)]
+                 data_chunks = []
+                 for subject_id_chunk in subject_id_chunks:
+                     data_chunks.append(data.filter(pl.col("subject_id").is_in(subject_id_chunk)))
+                 del data
+                 pool = multiprocessing.get_context("spawn").Pool(processes=args.workers)
+                 # the order is preserved
+                 length_per_subject_gathered = pool.map(meds_to_remed_partial, data_chunks)
+                 pool.close()
+                 pool.join()
+                 del data_chunks
+
+             if len(length_per_subject_gathered) != args.workers:
+                 print(
+                     "Number of processed workers were smaller than the specified num workers "
+                     "(--workers) due to the small size of data. Consider reducing the number of "
+                     "workers."
+                 )
+
+             for length_per_subject in length_per_subject_gathered:
+                 for subject_id, (length, shard_id) in length_per_subject.items():
+                     manifest_f.write(f"{subject_id}\t{length}\t{shard_id}\n")
+
+
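main() pairs each processed split with a manifest file named <output_name>.tsv: a header line holding the data directory, then one "<sample_id>\t<length>\t<shard_id>" row per sample, matching the write calls above. A hedged sketch of reading such a manifest back, assuming a hypothetical split called "train":

    # "outputs/train.tsv" is a hypothetical manifest path
    with open("outputs/train.tsv") as manifest_f:
        data_dir = manifest_f.readline().strip()  # e.g. "outputs/train"
        for line in manifest_f:
            sample_id, length, shard_id = line.rstrip("\n").split("\t")
            # the sample's data lives in f"{data_dir}/{shard_id}.h5" under the "ehr" group
            print(sample_id, int(length), int(shard_id))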
+ def meds_to_remed(
+     tokenizer,
+     rest_of_columns,
+     column_name_idcs,
+     codes_metadata,
+     output_dir,
+     output_name,
+     num_shards,
+     d_items,
+     d_labitems,
+     warned_codes,
+     max_event_length,
+     df_chunk,
+ ):
+     code_matching_pattern = re.compile(r"\d+")
+
+     def meds_to_remed_unit(row):
+         events = []
+         digit_offsets = []
+         col_name_offsets = []
+         for event_index in range(len(row[column_name_idcs["code"]])):
+             event = ""
+             digit_offset = []
+             col_name_offset = []
+             for col_name in ["code", "numeric_value"] + rest_of_columns:
+                 # do not process something like "icustay_id" or "hadm_id"
+                 if "id" in col_name:
+                     continue
+
+                 col_event = row[column_name_idcs[col_name]][event_index]
+                 if col_event is not None:
+                     col_event = str(col_event)
+                     if col_name == "code":
+                         if col_event in codes_metadata and codes_metadata[col_event] != "":
+                             col_event = codes_metadata[col_event]
+                         else:
+                             do_break = False
+                             items = col_event.split("//")
+                             is_code = [bool(code_matching_pattern.fullmatch(item)) for item in items]
+                             if True in is_code:
+                                 if d_items is not None and d_labitems is not None:
+                                     code_idx = is_code.index(True)
+                                     code = items[code_idx]
+
+                                     if code in d_items:
+                                         desc = d_items[code]
+                                     elif code in d_labitems:
+                                         desc = d_labitems[code]
+                                     else:
+                                         do_break = True
+
+                                     if not do_break:
+                                         items[code_idx] = desc
+                                         col_event = "//".join(items)
+                                 else:
+                                     do_break = True
+
+                             if do_break and col_event not in warned_codes:
+                                 warned_codes.append(col_event)
+                                 warnings.warn(
+                                     "The dataset contains some codes that are not specified in "
+                                     "the codes metadata, which may not be intended. Note that we "
+                                     f"process this code as it is for now: {col_event}."
+                                 )
+                     else:
+                         col_event = re.sub(
+                             r"\d*\.\d+",
+                             lambda x: str(round(float(x.group(0)), 4)),
+                             col_event,
+                         )
+                     event_offset = len(event) + len(col_name) + 1
+                     digit_offset_tmp = [
+                         g.span() for g in re.finditer(r"([0-9]+([.][0-9]*)?|[0-9]+|\.+)", col_event)
+                     ]
+
+                     internal_offset = 0
+                     for start, end in digit_offset_tmp:
+                         digit_offset.append(
+                             (
+                                 event_offset + start + internal_offset,
+                                 event_offset + end + (end - start) * 2 + internal_offset,
+                             )
+                         )
+                         internal_offset += (end - start) * 2
+
+                     col_event = re.sub(r"([0-9\.])", r" \1 ", col_event)
+
+                     col_name_offset.append((len(event), len(event) + len(col_name)))
+                     event += " " + col_name + " " + col_event
+             if len(event) > 0:
+                 events.append(event[1:])
+                 digit_offsets.append(digit_offset)
+                 col_name_offsets.append(col_name_offset)
+
+         tokenized_events = tokenizer(
+             events,
+             add_special_tokens=True,
+             padding="max_length",
+             max_length=max_event_length,
+             truncation=True,
+             return_tensors="np",
+             return_token_type_ids=False,
+             return_attention_mask=True,
+             return_offsets_mapping=True,
+         )
+         lengths_before_padding = tokenized_events["attention_mask"].sum(axis=1)
+
+         input_ids = tokenized_events["input_ids"]
+         dpe_ids = np.zeros(input_ids.shape, dtype=int)
+         for i, digit_offset in enumerate(digit_offsets):
+             for start, end in digit_offset:
+                 start_index, end_index = find_boundary_between(
+                     tokenized_events[i].offsets[: lengths_before_padding[i] - 1],
+                     start,
+                     end,
+                 )
+
+                 # define dpe ids for digits found
+                 num_digits = end_index - start_index
+                 # 119: token id for "."
+                 num_decimal_points = (input_ids[i][start_index:end_index] == 119).sum()
+
+                 # integer without decimal point
+                 # e.g., for "1 2 3 4 5", assign [10, 9, 8, 7, 6]
+                 if num_decimal_points == 0:
+                     dpe_ids[i][start_index:end_index] = list(range(num_digits + 5, 5, -1))
+                 # floats
+                 # e.g., for "1 2 3 4 5 . 6 7 8 9", assign [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+                 elif num_decimal_points == 1:
+                     num_decimals = (
+                         num_digits
+                         - np.where(input_ids[i][start_index:end_index] == 119)[0][0] # 119: token id for "."
+                     )
+                     dpe_ids[i][start_index:end_index] = list(
+                         range(num_digits + 5 - num_decimals, 5 - num_decimals, -1)
+                     )
+                 # 1 > decimal points where we cannot define dpe ids
+                 else:
+                     continue
+         # define type ids
+         # for column names: 2
+         # for column values (contents): 3
+         # for CLS tokens: 5
+         # for SEP tokens: 6
+         type_ids = np.zeros(input_ids.shape, dtype=int)
+         type_ids[:, 0] = 5 # CLS tokens
+         for i, col_name_offset in enumerate(col_name_offsets):
+             type_ids[i][lengths_before_padding[i] - 1] = 6 # SEP tokens
+             # fill with type ids for column values
+             type_ids[i][1 : lengths_before_padding[i] - 1] = 3
+             for start, end in col_name_offset:
+                 start_index, end_index = find_boundary_between(
+                     tokenized_events[i].offsets[1 : lengths_before_padding[i] - 1],
+                     start,
+                     end,
+                 )
+                 # the first offset is always (0, 0) for CLS token, so we adjust it
+                 start_index += 1
+                 end_index += 1
+                 # finally replace with type ids for column names
+                 type_ids[i][start_index:end_index] = 2
+
+         return np.stack([input_ids, type_ids, dpe_ids], axis=1).astype(np.uint16)
+
+     events_data = []
+     worker_id = multiprocessing.current_process().name.split("-")[-1]
+     if worker_id == "MainProcess":
+         worker_id = 0
+     else:
+         # worker_id is incremental for every generated pool, so divide with num_shards
+         worker_id = (int(worker_id) - 1) % num_shards # 1-based -> 0-based indexing
+     if worker_id == 0:
+         progress_bar = tqdm(df_chunk.iter_rows(), total=len(df_chunk))
+         progress_bar.set_description(f"Processing from worker-{worker_id}:")
+     else:
+         progress_bar = df_chunk.iter_rows()
+
+     for row in progress_bar:
+         events_data.append(meds_to_remed_unit(row))
+     data_length = list(map(len, events_data))
+     data_index_offset = np.zeros(len(data_length), dtype=np.int64)
+     data_index_offset[1:] = np.cumsum(data_length[:-1])
+     data_index = pl.Series(
+         "data_index",
+         map(
+             lambda x: [data_index_offset[x] + y for y in range(data_length[x])],
+             range(len(data_length)),
+         ),
+     )
+     events_data = np.concatenate(events_data)
+
+     df_chunk = df_chunk.select(["subject_id", "cohort_end", "cohort_label", "time"])
+     df_chunk = df_chunk.insert_column(4, data_index)
+     df_chunk = df_chunk.explode(["cohort_end", "cohort_label"])
+     df_chunk = df_chunk.group_by(
+         # ["subject_id", "cohort_start", "cohort_end", "cohort_label"], maintain_order=True
+         ["subject_id", "cohort_end", "cohort_label"],
+         maintain_order=True,
+     ).agg(pl.all())
+
+     df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
+     # regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
+     df_chunk = df_chunk.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
+     df_chunk = df_chunk.with_columns(
+         (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
+     )
+     # data = data.drop("suffix", "cohort_start", "cohort_end")
+     df_chunk = df_chunk.drop("suffix", "cohort_end")
+
+     length_per_subject = {}
+     progress_bar = tqdm(
+         df_chunk.iter_rows(),
+         total=len(df_chunk),
+         desc=f"Writing data from worker-{worker_id}:",
+     )
+
+     for sample in progress_bar:
+         with h5py.File(str(output_dir / output_name / f"{worker_id}.h5"), "a") as f:
+             if "ehr" in f:
+                 result = f["ehr"]
+             else:
+                 result = f.create_group("ehr")
+
+             sample_result = result.create_group(sample[0])
+
+             data_indices = np.concatenate(sample[3])
+             data = events_data[data_indices]
+             sample_result.create_dataset("hi", data=data, dtype="i2", compression="lzf", shuffle=True)
+
+             times = np.concatenate(sample[2])
+             times = [datetime.strptime(x, "%Y-%m-%d %H:%M:%S") for x in times]
+             times = np.cumsum(np.diff(times))
+             times = list(map(lambda x: round(x.total_seconds() / 60), times))
+             times = np.array([0] + times)
+
+             sample_result.create_dataset("time", data=times, dtype="i")
+             sample_result.create_dataset("label", data=int(sample[1]))
+
+             length_per_subject[sample[0]] = (len(times), worker_id)
+     del df_chunk
+
+     return length_per_subject
+
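The digit-place-embedding (dpe) ids assigned inside meds_to_remed_unit encode each digit's distance from the decimal point, as the inline examples above describe. A toy reproduction of just that rule (not package code), operating on an already digit-split token list:

    def toy_dpe_ids(tokens):
        # integer part counts down to 6; "." and fractional digits continue below 5
        n = len(tokens)
        if "." not in tokens:  # e.g. "1 2 3 4 5"
            return list(range(n + 5, 5, -1))
        num_decimals = n - tokens.index(".")  # "." plus the fractional digits
        return list(range(n + 5 - num_decimals, 5 - num_decimals, -1))

    print(toy_dpe_ids(list("12345")))       # [10, 9, 8, 7, 6]
    print(toy_dpe_ids(list("12345.6789")))  # [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]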
+
+ if __name__ == "__main__":
+     parser = get_parser()
+     args = parser.parse_args()
+     main(args)
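Each worker writes one HDF5 shard per split. A hedged sketch of reading a shard back, assuming a hypothetical file "outputs/train/0.h5" (the shard name comes from the worker id); the group layout ("ehr/<sample_id>" with "hi", "time", and "label" datasets) mirrors the create_group/create_dataset calls in meds_to_remed above:

    import h5py

    with h5py.File("outputs/train/0.h5", "r") as f:  # hypothetical shard path
        for sample_id, sample in f["ehr"].items():
            hi = sample["hi"][:]       # (num_events, 3, max_event_length): input, type, and dpe ids
            times = sample["time"][:]  # minutes elapsed since the first event
            label = sample["label"][()]
            print(sample_id, hi.shape, times.shape, int(label))
            break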