genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/scripts/preprocess/preprocess_meds.py
@@ -0,0 +1,674 @@
+import functools
+import logging
+import glob
+import multiprocessing
+import os
+import re
+import shutil
+from argparse import ArgumentParser
+import argparse
+from bisect import bisect_left, bisect_right
+from datetime import datetime
+from pathlib import Path
+
+import h5py
+import numpy as np
+import pandas as pd
+import polars as pl
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+pool_manager = multiprocessing.Manager()
+warned_codes = pool_manager.list()
+
+
+def find_boundary_between(tuples_list, start, end):
+    starts = [s for s, e in tuples_list]
+    ends = [e for s, e in tuples_list]
+
+    start_index = bisect_left(starts, start)
+    end_index = bisect_right(ends, end)
+    assert start_index < end_index
+
+    return start_index, end_index
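As a quick illustration of how `find_boundary_between` maps a character span back to token indices with `bisect`, here is a small self-contained check; the offsets and span are made-up values, not taken from the package:

```python
from bisect import bisect_left, bisect_right

def find_boundary_between(tuples_list, start, end):
    # mirrors the helper above: first token whose start is >= `start`,
    # and the position just past the last token whose end is <= `end`
    starts = [s for s, e in tuples_list]
    ends = [e for s, e in tuples_list]
    start_index = bisect_left(starts, start)
    end_index = bisect_right(ends, end)
    assert start_index < end_index
    return start_index, end_index

# hypothetical token offsets for the text "code Heart Rate", one (start, end) pair per token
offsets = [(0, 4), (5, 10), (11, 15)]
# the character span of "Heart Rate" is (5, 15), which covers tokens 1 and 2
print(find_boundary_between(offsets, 5, 15))  # -> (1, 3)
```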
+
+
+def get_parser():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "root",
+        help="path to MEDS dataset. it can be either of directory or the exact file path "
+        "with the file extension. if provided with directory, it tries to scan *.csv or "
+        "*.parquet files contained in the directory, including sub-directories, to process "
+        "all of them.",
+    )
+    parser.add_argument(
+        "--metadata_dir",
+        help="path to metadata directory for the input MEDS dataset, which contains codes.parquet",
+    )
+
+    parser.add_argument(
+        "--birth_code", type=str, default="MEDS_BIRTH", help="string code for the birth event in the dataset."
+    )
+
+    parser.add_argument(
+        "--cohort",
+        type=str,
+        help="path to the defined cohort, which must be a result of ACES. it can be either of "
+        "directory or the exact file path that has the same extension with the MEDS dataset "
+        "to be processed. the file structure of this cohort directory should be the same with "
+        "the provided MEDS dataset directory to match each cohort to its corresponding shard "
+        "data.",
+    )
+    parser.add_argument(
+        "--cohort_label_name",
+        type=str,
+        default="boolean_value",
+        help="column name in the cohort dataframe to be used for label",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="outputs",
+        help="directory to save processed outputs.",
+    )
+    parser.add_argument(
+        "--skip-if-exists",
+        action="store_true",
+        help="whether or not to skip the processing if the output directory already "
+        "exists.",
+    )
+    parser.add_argument(
+        "--rebase",
+        action="store_true",
+        help="whether or not to rebase the output directory if exists.",
+    )
+    parser.add_argument(
+        "--debug",
+        type=lambda x: {'true': True, 'false': False}[x.lower()],
+        default=False,
+        help="whether or not to enable the debug mode, which forces the script to be run with "
+        "only one worker."
+    )
+    parser.add_argument(
+        "--workers",
+        metavar="N",
+        default=1,
+        type=int,
+        help="number of parallel workers.",
+    )
+    parser.add_argument(
+        "--max_event_length",
+        metavar="N",
+        default=128,
+        type=int,
+        help="maximum number of tokens in an event.",
+    )
+
+    parser.add_argument(
+        "--mimic_dir",
+        default=None,
+        help="path to directory for MIMIC-IV database containing hosp/ and icu/ as a subdirectory. "
+        "this is used for addressing missing descriptions in the metadata for MIMIC-IV codes.",
+    )
+
+    return parser
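A minimal sketch of exercising the parser defined above, assuming the wheel is installed so that `genhpf.scripts.preprocess.preprocess_meds` is importable; the paths are placeholders, not real data:

```python
if __name__ == "__main__":
    # import inside the guard: the module creates a multiprocessing.Manager at import time
    from genhpf.scripts.preprocess.preprocess_meds import get_parser

    args = get_parser().parse_args(
        [
            "data/meds",                 # positional `root`
            "--metadata_dir", "data/meds/metadata",
            "--cohort", "data/cohorts",
            "--output_dir", "outputs",
            "--workers", "4",
        ]
    )
    print(args.birth_code)        # MEDS_BIRTH (default)
    print(args.max_event_length)  # 128 (default)
    print(args.skip_if_exists)    # False; `--skip-if-exists` is exposed as `skip_if_exists`
```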
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    root_path = Path(args.root)
+    output_dir = Path(args.output_dir)
+    metadata_dir = Path(args.metadata_dir)
+    mimic_dir = Path(args.mimic_dir) if args.mimic_dir is not None else None
+
+    num_workers = max(args.workers, 1)
+    if args.debug:
+        print("debug mode is ON")
+        num_workers = 1
+        os.environ["RAYON_RS_NUM_CPUS"] = "1"
+    else:
+        cpu_count = multiprocessing.cpu_count()
+        if num_workers > cpu_count:
+            logger.warning(
+                f"Number of workers (--workers) is greater than the number of available CPUs "
+                f"({cpu_count}). Setting the number of workers to {cpu_count}."
+            )
+            num_workers = cpu_count
+
+    if root_path.is_dir():
+        data_paths = glob.glob(str(root_path / "**/*.csv"), recursive=True)
+        if len(data_paths) == 0:
+            data_paths = glob.glob(str(root_path / "**/*.parquet"), recursive=True)
+        if len(data_paths) == 0:
+            raise ValueError("Data directory does not contain any supported file formats: .csv or .parquet")
+    else:
+        data_paths = [root_path]
+
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+    else:
+        if args.rebase:
+            shutil.rmtree(output_dir)
+            output_dir.mkdir(parents=True)
+        elif output_dir.exists():
+            if args.skip_if_exists:
+                ls = glob.glob(str(output_dir / "**/*"), recursive=True)
+                expected_files = []
+                for subset in set(os.path.dirname(x) for x in data_paths):
+                    expected_files.extend([
+                        os.path.join(str(output_dir), os.path.basename(subset), f"{i}.h5")
+                        for i in range(num_workers)
+                    ])
+                if set(expected_files).issubset(set(ls)):
+                    logger.info(
+                        f"Output directory already contains the expected files. Skipping the "
+                        "processing as --skip-if-exists is set. If you want to rebase the directory, "
+                        "please run the script with --rebase."
+                    )
+                    return
+            else:
+                raise ValueError(
+                    f"File exists: '{str(output_dir.resolve())}'. If you want to rebase the "
+                    "directory automatically, please run the script with --rebase."
+                )
+
+    label_col_name = args.cohort_label_name
+
+    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+    codes_metadata = pl.read_parquet(metadata_dir / "codes.parquet").to_pandas()
+    codes_metadata = codes_metadata.set_index("code")["description"].to_dict()
+    # do not allow to use static events or birth event
+    birth_code = args.birth_code
+    # if birth_code not in codes_metadata:
+    #     print(
+    #         f'"{birth_code}" is not found in the codes metadata, which may lead to '
+    #         "unexpected results since we currently exclude this event from the input data. "
+    #     )
+
+    if mimic_dir is not None:
+        d_items = pd.read_csv(mimic_dir / "icu" / "d_items.csv.gz")
+        d_items["itemid"] = d_items["itemid"].astype("str")
+        d_items = d_items.set_index("itemid")["label"].to_dict()
+        d_labitems = pd.read_csv(mimic_dir / "hosp" / "d_labitems.csv.gz")
+        d_labitems["itemid"] = d_labitems["itemid"].astype("str")
+        d_labitems = d_labitems.set_index("itemid")["label"].to_dict()
+    else:
+        d_items = None
+        d_labitems = None
+
+    max_event_length = args.max_event_length
+
+    progress_bar = tqdm(data_paths, total=len(data_paths))
+    for data_path in progress_bar:
+        progress_bar.set_description(str(data_path))
+
+        data_path = Path(data_path)
+        subdir = data_path.relative_to(root_path).parent
+        if data_path.suffix == ".csv":
+            data = pl.scan_csv(
+                data_path,
+                low_memory=True if args.debug else False,
+            )
+        elif data_path.suffix == ".parquet":
+            data = pl.scan_parquet(
+                data_path,
+                parallel="none" if args.debug else "auto",
+                low_memory=True if args.debug else False,
+            )
+        else:
+            raise ValueError(f"Unsupported file format: {data_path.suffix}")
+
+        data = data.with_columns(
+            pl.when(pl.col("code") == birth_code).then(None).otherwise(pl.col("time")).alias("time")
+        )
+        data = data.drop_nulls(subset=["subject_id", "time"])
+
+        cohort_path = Path(args.cohort) / subdir / data_path.name
+
+        if cohort_path.suffix == ".csv":
+            cohort = pl.scan_csv(cohort_path)
+        elif cohort_path.suffix == ".parquet":
+            cohort = pl.scan_parquet(cohort_path)
+        else:
+            raise ValueError(f"Unsupported file format: {cohort_path.suffix}")
+
+        cohort = cohort.drop_nulls(label_col_name)
+        cohort = cohort.unique()
+
+        cohort = cohort.select(
+            [
+                pl.col("subject_id"),
+                pl.col(label_col_name),
+                # pl.col("input.end_summary").struct.field("timestamp_at_start").alias("starttime"),
+                pl.col("prediction_time").alias("endtime"),
+            ]
+        )
+        cohort = (
+            cohort.group_by("subject_id", maintain_order=True)
+            .agg(pl.col(["endtime", label_col_name]))
+            .collect()
+        )  # omitted "starttime"
+        cohort_dict = {
+            x["subject_id"]: {
+                # "starttime": x["starttime"],
+                "endtime": x["endtime"],
+                "label": x[label_col_name],
+            }
+            for x in cohort.iter_rows(named=True)
+        }
+
+        def extract_cohort(row):
+            subject_id = row["subject_id"]
+            time = row["time"]
+            if subject_id not in cohort_dict:
+                # return {"cohort_start": None, "cohort_end": None, "cohort_label": None}
+                return {"cohort_end": None, "cohort_label": None}
+
+            cohort_criteria = cohort_dict[subject_id]
+            # starts = cohort_criteria["starttime"]
+            ends = cohort_criteria["endtime"]
+            labels = cohort_criteria["label"]
+
+            # for start, end, label in zip(starts, ends, labels):
+            #     if start <= time and time <= end:
+            #         return {"cohort_start": start, "cohort_end": end, "cohort_label": label}
+
+            # assume it is possible that each event goes into multiple different cohorts
+            cohort_ends = []
+            cohort_labels = []
+            for end, label in zip(ends, labels):
+                if time <= end:
+                    # return {"cohort_start": start, "cohort_end": end, "cohort_label": label}
+                    cohort_ends.append(end)
+                    cohort_labels.append(label)
+
+            if len(cohort_ends) > 0:
+                return {"cohort_end": cohort_ends, "cohort_label": cohort_labels}
+            else:
+                # return {"cohort_start": None, "cohort_end": None, "cohort_label": None}
+                return {"cohort_end": None, "cohort_label": None}
+
+        data = data.group_by(["subject_id", "time"], maintain_order=True).agg(pl.all())
+        data = (
+            data.with_columns(
+                pl.struct(["subject_id", "time"])
+                .map_elements(
+                    extract_cohort,
+                    return_dtype=pl.Struct(
+                        {
+                            "cohort_end": pl.List(pl.Datetime()),
+                            "cohort_label": pl.List(pl.Boolean),
+                        }
+                    ),
+                )
+                .alias("cohort_criteria")
+            )
+            .unnest("cohort_criteria")
+            .collect()
+        )
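The labelling rule implemented by `extract_cohort` above amounts to "attach every cohort whose `prediction_time` has not yet passed to the event". A standalone toy run of the same rule, with a hypothetical subject and made-up timestamps (no polars involved):

```python
from datetime import datetime

# made-up cohort table for one subject: two prediction times with boolean labels,
# mirroring what `cohort_dict` holds after the group_by above
cohort_dict = {
    1: {
        "endtime": [datetime(2023, 1, 2, 12, 0), datetime(2023, 1, 5, 12, 0)],
        "label": [True, False],
    }
}

def extract_cohort(row):
    # same rule as the nested helper above: keep every cohort whose endtime is still ahead
    if row["subject_id"] not in cohort_dict:
        return {"cohort_end": None, "cohort_label": None}
    criteria = cohort_dict[row["subject_id"]]
    kept = [(end, lab) for end, lab in zip(criteria["endtime"], criteria["label"]) if row["time"] <= end]
    if not kept:
        return {"cohort_end": None, "cohort_label": None}
    ends, labels = zip(*kept)
    return {"cohort_end": list(ends), "cohort_label": list(labels)}

# an event before both prediction times is assigned to both cohorts ...
print(extract_cohort({"subject_id": 1, "time": datetime(2023, 1, 1, 8, 0)}))
# ... an event between them only to the second one ...
print(extract_cohort({"subject_id": 1, "time": datetime(2023, 1, 3, 8, 0)}))
# ... and an event after both (or from an unknown subject) gets None and is dropped later
print(extract_cohort({"subject_id": 1, "time": datetime(2023, 1, 6, 8, 0)}))
```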
+
+        data = data.drop_nulls("cohort_label")
+
+        data = data.with_columns(pl.col("time").dt.strftime("%Y-%m-%d %H:%M:%S").cast(pl.List(str)))
+        data = data.with_columns(
+            pl.col("time").list.sample(n=pl.col("code").list.len(), with_replacement=True)
+        )
+        if args.debug:
+            data = data[:5000]
+
+        if str(subdir) != ".":
+            output_name = str(subdir)
+        else:
+            output_name = data_path.stem
+
+        if not os.path.exists(output_dir / output_name):
+            os.makedirs(output_dir / output_name)
+
+        with open(str(output_dir / (output_name + ".tsv")), "a") as manifest_f:
+            if os.path.getsize(output_dir / (output_name + ".tsv")) == 0:
+                manifest_f.write(f"{output_dir}/{output_name}\n")
+
+            must_have_columns = [
+                "subject_id",
+                "cohort_end",
+                "cohort_label",
+                "time",
+                "code",
+                "numeric_value",
+            ]
+            rest_of_columns = [x for x in data.columns if x not in must_have_columns]
+            column_name_idcs = {col: i for i, col in enumerate(data.columns)}
+
+            meds_to_remed_partial = functools.partial(
+                meds_to_remed,
+                tokenizer,
+                rest_of_columns,
+                column_name_idcs,
+                codes_metadata,
+                output_dir,
+                output_name,
+                num_workers,
+                d_items,
+                d_labitems,
+                warned_codes,
+                max_event_length,
+                args.debug,
+            )
+
+            # meds --> remed
+            logger.info(f"Start processing {data_path}")
+            if num_workers <= 1:
+                length_per_subject_gathered = [meds_to_remed_partial(data)]
+                del data
+            else:
+                subject_ids = data["subject_id"].unique().to_list()
+                n = num_workers
+                subject_id_chunks = [subject_ids[i::n] for i in range(n)]
+                data_chunks = []
+                for subject_id_chunk in subject_id_chunks:
+                    data_chunks.append(data.filter(pl.col("subject_id").is_in(subject_id_chunk)))
+                del data
+
+                num_valid_data_chunks = sum(map(lambda x: len(x) > 0, data_chunks))
+                if num_valid_data_chunks < num_workers:
+                    raise ValueError(
+                        "Number of valid data chunks (= number of unique subjects) were smaller "
+                        "than the specified num workers (--workers) due to the small size of data. "
+                        "Consider reducing the number of workers."
+                    )
+
+                pool = multiprocessing.get_context("spawn").Pool(processes=num_workers)
+                # the order is preserved
+                length_per_subject_gathered = pool.map(meds_to_remed_partial, data_chunks)
+                pool.close()
+                pool.join()
+                del data_chunks
+
+                if len(length_per_subject_gathered) != num_workers:
+                    raise ValueError(
+                        "Number of processed workers were smaller than the specified num workers "
+                        "(--workers) due to the small size of data. Consider reducing the number of "
+                        "workers."
+                    )
+
+            for length_per_subject in length_per_subject_gathered:
+                for subject_id, (length, shard_id) in length_per_subject.items():
+                    manifest_f.write(f"{subject_id}\t{length}\t{shard_id}\n")
+
+
+def meds_to_remed(
+    tokenizer,
+    rest_of_columns,
+    column_name_idcs,
+    codes_metadata,
+    output_dir,
+    output_name,
+    num_shards,
+    d_items,
+    d_labitems,
+    warned_codes,
+    max_event_length,
+    debug,
+    df_chunk,
+):
+    code_matching_pattern = re.compile(r"\d+")
+
+    def meds_to_remed_unit(row):
+        events = []
+        digit_offsets = []
+        col_name_offsets = []
+        for event_index in range(len(row[column_name_idcs["code"]])):
+            event = ""
+            digit_offset = []
+            col_name_offset = []
+            for col_name in ["code", "numeric_value"] + rest_of_columns:
+                # do not process something like "icustay_id" or "hadm_id"
+                if "id" in col_name:
+                    continue
+
+                col_event = row[column_name_idcs[col_name]][event_index]
+                # print(f"col_name: {col_name}")
+                # print(row[column_name_idcs[col_name]][event_index])
+                if col_event is not None:
+                    col_event = str(col_event)
+                    # if col_name == "code":
+                    #     if col_event in codes_metadata and codes_metadata[col_event] != "":
+                    #         col_event = codes_metadata[col_event]
+                    #     else:
+                    #         do_break = False
+                    # safely resolve description: check for None or empty string
+                    if col_name == "code":
+                        desc = codes_metadata.get(col_event)
+                        if desc is not None and desc != "":
+                            col_event = desc
+                        else:
+                            do_break = False
+                            items = col_event.split("//")
+                            is_code = [bool(code_matching_pattern.fullmatch(item)) for item in items]
+                            if True in is_code:
+                                if d_items is not None and d_labitems is not None:
+                                    code_idx = is_code.index(True)
+                                    code = items[code_idx]
+
+                                    if code in d_items:
+                                        desc = d_items[code]
+                                    elif code in d_labitems:
+                                        desc = d_labitems[code]
+                                    else:
+                                        do_break = True
+
+                                    if not do_break:
+                                        items[code_idx] = desc
+                                        col_event = "//".join(items)
+                                else:
+                                    do_break = True
+                            if do_break and col_event not in warned_codes:
+                                warned_codes.append(col_event)
+                                logger.warning(
+                                    "The dataset contains some codes that are not specified in "
+                                    "the codes metadata, which may not be intended. Note that we "
+                                    f"process this code as it is for now: {col_event}."
+                                )
+                    else:
+                        col_event = re.sub(
+                            r"\d*\.\d+",
+                            lambda x: str(round(float(x.group(0)), 4)),
+                            col_event,
+                        )
+                    event_offset = len(event) + len(col_name) + 1
+                    digit_offset_tmp = [
+                        g.span() for g in re.finditer(r"([0-9]+([.][0-9]*)?|[0-9]+|\.+)", col_event)
+                    ]
+
+                    internal_offset = 0
+                    for start, end in digit_offset_tmp:
+                        digit_offset.append(
+                            (
+                                event_offset + start + internal_offset,
+                                event_offset + end + (end - start) * 2 + internal_offset,
+                            )
+                        )
+                        internal_offset += (end - start) * 2
+
+                    col_event = re.sub(r"([0-9\.])", r" \1 ", col_event)
+                    if col_event is None:
+                        logger.warning(f"Skipped col_event for col_name {col_name} because it is None")
+                        continue
+                    col_name_offset.append((len(event), len(event) + len(col_name)))
+                    event += " " + col_name + " " + col_event
+            if len(event) > 0:
+                events.append(event[1:])
+                digit_offsets.append(digit_offset)
+                col_name_offsets.append(col_name_offset)
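The two regular expressions used while serializing each event do the numeric normalisation: decimal values are rounded to four places, and every digit (and decimal point) is then split into its own whitespace-separated symbol so the wordpiece tokenizer emits one token per digit. A toy demonstration on a made-up value:

```python
import re

col_name = "numeric_value"
col_event = "36.666666"  # a made-up measurement value

# round any decimal number to four places, as in the non-"code" branch above
col_event = re.sub(r"\d*\.\d+", lambda m: str(round(float(m.group(0)), 4)), col_event)
print(col_event)  # 36.6667

# then space-separate every digit and decimal point
spaced = re.sub(r"([0-9\.])", r" \1 ", col_event)
print(spaced)  # " 3  6  .  6  6  6  7 "

# the serialized event text interleaves the column name and its normalised value
event = " " + col_name + " " + spaced
print(event.strip())
```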
+
+        tokenized_events = tokenizer(
+            events,
+            add_special_tokens=True,
+            padding="max_length",
+            max_length=max_event_length,
+            truncation=True,
+            return_tensors="np",
+            return_token_type_ids=False,
+            return_attention_mask=True,
+            return_offsets_mapping=True,
+        )
+        lengths_before_padding = tokenized_events["attention_mask"].sum(axis=1)
+
+        input_ids = tokenized_events["input_ids"]
+        dpe_ids = np.zeros(input_ids.shape, dtype=int)
+        for i, digit_offset in enumerate(digit_offsets):
+            for start, end in digit_offset:
+                start_index, end_index = find_boundary_between(
+                    tokenized_events[i].offsets[: lengths_before_padding[i] - 1],
+                    start,
+                    end,
+                )
+
+                # define dpe ids for digits found
+                num_digits = end_index - start_index
+                # 119: token id for "."
+                num_decimal_points = (input_ids[i][start_index:end_index] == 119).sum()
+
+                # integer without decimal point
+                # e.g., for "1 2 3 4 5", assign [10, 9, 8, 7, 6]
+                if num_decimal_points == 0:
+                    dpe_ids[i][start_index:end_index] = list(range(num_digits + 5, 5, -1))
+                # floats
+                # e.g., for "1 2 3 4 5 . 6 7 8 9", assign [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+                elif num_decimal_points == 1:
+                    num_decimals = (
+                        num_digits
+                        - np.where(input_ids[i][start_index:end_index] == 119)[0][0]  # 119: token id for "."
+                    )
+                    dpe_ids[i][start_index:end_index] = list(
+                        range(num_digits + 5 - num_decimals, 5 - num_decimals, -1)
+                    )
+                # 1 > decimal points where we cannot define dpe ids
+                else:
+                    continue
+        # define type ids
+        # for column names: 2
+        # for column values (contents): 3
+        # for CLS tokens: 5
+        # for SEP tokens: 6
+        type_ids = np.zeros(input_ids.shape, dtype=int)
+        type_ids[:, 0] = 5  # CLS tokens
+        for i, col_name_offset in enumerate(col_name_offsets):
+            type_ids[i][lengths_before_padding[i] - 1] = 6  # SEP tokens
+            # fill with type ids for column values
+            type_ids[i][1 : lengths_before_padding[i] - 1] = 3
+            for start, end in col_name_offset:
+                start_index, end_index = find_boundary_between(
+                    tokenized_events[i].offsets[1 : lengths_before_padding[i] - 1],
+                    start,
+                    end,
+                )
+                # the first offset is always (0, 0) for CLS token, so we adjust it
+                start_index += 1
+                end_index += 1
+                # finally replace with type ids for column names
+                type_ids[i][start_index:end_index] = 2
+
+        return np.stack([input_ids, type_ids, dpe_ids], axis=1).astype(np.uint16)
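The digit-place embedding (dpe) ids assigned above encode each digit's distance from the decimal point: the id 5 effectively marks the decimal point, integer digits count up from 6, and fractional digits count down from 4. A small standalone re-implementation of the same arithmetic, working on single characters instead of wordpiece ids (a simplified sketch):

```python
def dpe_ids_for(number_tokens):
    """Mirror of the dpe rule above for a list of single-character digit/'.' tokens."""
    num_digits = len(number_tokens)
    decimal_positions = [i for i, tok in enumerate(number_tokens) if tok == "."]
    if len(decimal_positions) == 0:
        # integer: the ones digit gets 6, the tens digit 7, and so on
        return list(range(num_digits + 5, 5, -1))
    if len(decimal_positions) == 1:
        num_decimals = num_digits - decimal_positions[0]
        return list(range(num_digits + 5 - num_decimals, 5 - num_decimals, -1))
    return [0] * num_digits  # more than one ".": left at 0, as the `continue` above does

print(dpe_ids_for(list("12345")))       # [10, 9, 8, 7, 6]
print(dpe_ids_for(list("12345.6789")))  # [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
print(dpe_ids_for(list("3.14")))        # [6, 5, 4, 3]
```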
+
+    events_data = []
+    worker_id = multiprocessing.current_process().name.split("-")[-1]
+    if worker_id == "MainProcess":
+        worker_id = 0
+    else:
+        # worker_id is incremental for every generated pool, so divide with num_shards
+        worker_id = (int(worker_id) - 1) % num_shards  # 1-based -> 0-based indexing
+    if worker_id == 0:
+        progress_bar = tqdm(df_chunk.iter_rows(), total=len(df_chunk))
+        progress_bar.set_description(f"Processing from worker-{worker_id}:")
+    else:
+        progress_bar = df_chunk.iter_rows()
+
+    for row in progress_bar:
+        events_data.append(meds_to_remed_unit(row))
+    data_length = list(map(len, events_data))
+    data_index_offset = np.zeros(len(data_length), dtype=np.int64)
+    data_index_offset[1:] = np.cumsum(data_length[:-1])
+    data_index = pl.Series(
+        "data_index",
+        map(
+            lambda x: [data_index_offset[x] + y for y in range(data_length[x])],
+            range(len(data_length)),
+        ),
+    )
+    events_data = np.concatenate(events_data)
+
+    df_chunk = df_chunk.select(["subject_id", "cohort_end", "cohort_label", "time"])
+    df_chunk = df_chunk.insert_column(4, data_index)
+    df_chunk = df_chunk.explode(["cohort_end", "cohort_label"])
+    df_chunk = df_chunk.group_by(
+        # ["subject_id", "cohort_start", "cohort_end", "cohort_label"], maintain_order=True
+        ["subject_id", "cohort_end", "cohort_label"],
+        maintain_order=True,
+    ).agg(pl.all())
+
+    if debug:
+        print("debug_mode is on!")
+        df_chunk = df_chunk.with_columns(
+            [
+                pl.col("time").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(str))),
+                pl.col("data_index").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(int)))
+            ]
+        )
+
+    df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
+    # regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
+    df_chunk = df_chunk.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
+    df_chunk = df_chunk.with_columns(
+        (pl.col("subject_id").cast(str) + "_" + pl.col("suffix").cast(str)).alias("subject_id")
+    )
+    # data = data.drop("suffix", "cohort_start", "cohort_end")
+    df_chunk = df_chunk.drop("suffix", "cohort_end")
+
+    length_per_subject = {}
+    progress_bar = tqdm(
+        df_chunk.iter_rows(),
+        total=len(df_chunk),
+        desc=f"Writing data from worker-{worker_id}:",
+    )
+
+    for sample in progress_bar:
+        with h5py.File(str(output_dir / output_name / f"{worker_id}.h5"), "a") as f:
+            if "ehr" in f:
+                result = f["ehr"]
+            else:
+                result = f.create_group("ehr")
+
+            sample_result = result.create_group(sample[0])
+
+            times = np.concatenate(sample[2])
+            data_indices = np.concatenate(sample[3])
+            if debug:
+                data_indices = data_indices[-100:]
+                times = times[-100:]
+
+            data = events_data[data_indices]
+            sample_result.create_dataset("hi", data=data, dtype="i2", compression="lzf", shuffle=True)
+
+            times = [datetime.strptime(x, "%Y-%m-%d %H:%M:%S") for x in times]
+            times = np.cumsum(np.diff(times))
+            times = list(map(lambda x: round(x.total_seconds() / 60), times))
+            times = np.array([0] + times)
+
+            sample_result.create_dataset("time", data=times, dtype="i")
+            sample_result.create_dataset("label", data=int(sample[1]))
+
+            length_per_subject[sample[0]] = (len(times), worker_id)
+    del df_chunk
+
+    return length_per_subject
+
+
+if __name__ == "__main__":
+    main()
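Based on the writing code above, each worker produces `{output_dir}/{output_name}/{worker_id}.h5` with one `ehr/<cohort_id>` group holding the `hi` (tokenized events), `time` (minutes since the first event), and `label` datasets, plus a `{output_name}.tsv` manifest next to it. A minimal sketch for inspecting one shard, assuming such files exist at the placeholder paths used here:

```python
import h5py

# placeholder paths: adjust to the actual --output_dir, output name, and worker id
shard_path = "outputs/train/0.h5"

with h5py.File(shard_path, "r") as f:
    ehr = f["ehr"]
    for cohort_id in list(ehr)[:3]:  # keys look like "<subject_id>_<cohort_number>"
        sample = ehr[cohort_id]
        hi = sample["hi"][:]       # shape: (num_events, 3, max_event_length)
        times = sample["time"][:]  # minutes elapsed since the first event
        label = sample["label"][()]
        print(cohort_id, hi.shape, times[:5], int(label))

# the companion manifest lists every cohort id with its sequence length and shard id
with open("outputs/train.tsv") as manifest:
    print(manifest.readline().strip())  # first line: "<output_dir>/<output_name>"
    for line in list(manifest)[:3]:
        cohort_id, length, shard_id = line.rstrip("\n").split("\t")
        print(cohort_id, length, shard_id)
```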