genhpf 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of genhpf might be problematic. Click here for more details.
- genhpf/scripts/preprocess/genhpf/main.py +5 -4
- genhpf/scripts/preprocess/preprocess_meds.py +76 -22
- {genhpf-1.0.2.dist-info → genhpf-1.0.4.dist-info}/METADATA +2 -1
- {genhpf-1.0.2.dist-info → genhpf-1.0.4.dist-info}/RECORD +8 -8
- {genhpf-1.0.2.dist-info → genhpf-1.0.4.dist-info}/LICENSE +0 -0
- {genhpf-1.0.2.dist-info → genhpf-1.0.4.dist-info}/WHEEL +0 -0
- {genhpf-1.0.2.dist-info → genhpf-1.0.4.dist-info}/entry_points.txt +0 -0
- {genhpf-1.0.2.dist-info → genhpf-1.0.4.dist-info}/top_level.txt +0 -0
|
@@ -151,7 +151,10 @@ def get_parser():
|
|
|
151
151
|
return parser
|
|
152
152
|
|
|
153
153
|
|
|
154
|
-
def main(
|
|
154
|
+
def main():
|
|
155
|
+
parser = get_parser()
|
|
156
|
+
args = parser.parse_args()
|
|
157
|
+
|
|
155
158
|
if not os.path.exists(args.dest):
|
|
156
159
|
os.makedirs(args.dest)
|
|
157
160
|
|
|
@@ -169,6 +172,4 @@ def main(args):
|
|
|
169
172
|
|
|
170
173
|
|
|
171
174
|
if __name__ == "__main__":
|
|
172
|
-
|
|
173
|
-
args = parser.parse_args()
|
|
174
|
-
main(args)
|
|
175
|
+
main()
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import functools
|
|
2
|
+
import logging
|
|
2
3
|
import glob
|
|
3
4
|
import multiprocessing
|
|
4
5
|
import os
|
|
5
6
|
import re
|
|
6
7
|
import shutil
|
|
7
|
-
import warnings
|
|
8
8
|
from argparse import ArgumentParser
|
|
9
9
|
from bisect import bisect_left, bisect_right
|
|
10
10
|
from datetime import datetime
|
|
@@ -17,6 +17,9 @@ import polars as pl
|
|
|
17
17
|
from tqdm import tqdm
|
|
18
18
|
from transformers import AutoTokenizer
|
|
19
19
|
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
logger.setLevel(logging.INFO)
|
|
22
|
+
|
|
20
23
|
pool_manager = multiprocessing.Manager()
|
|
21
24
|
warned_codes = pool_manager.list()
|
|
22
25
|
|
|
@@ -71,11 +74,24 @@ def get_parser():
|
|
|
71
74
|
default="outputs",
|
|
72
75
|
help="directory to save processed outputs.",
|
|
73
76
|
)
|
|
77
|
+
parser.add_argument(
|
|
78
|
+
"--skip-if-exists",
|
|
79
|
+
action="store_true",
|
|
80
|
+
help="whether or not to skip the processing if the output directory already "
|
|
81
|
+
"exists.",
|
|
82
|
+
)
|
|
74
83
|
parser.add_argument(
|
|
75
84
|
"--rebase",
|
|
76
85
|
action="store_true",
|
|
77
86
|
help="whether or not to rebase the output directory if exists.",
|
|
78
87
|
)
|
|
88
|
+
parser.add_argument(
|
|
89
|
+
"--debug",
|
|
90
|
+
type=bool,
|
|
91
|
+
default=False,
|
|
92
|
+
help="whether or not to enable the debug mode, which forces the script to be run with "
|
|
93
|
+
"only one worker."
|
|
94
|
+
)
|
|
79
95
|
parser.add_argument(
|
|
80
96
|
"--workers",
|
|
81
97
|
metavar="N",
|
|
@@ -101,23 +117,26 @@ def get_parser():
|
|
|
101
117
|
return parser
|
|
102
118
|
|
|
103
119
|
|
|
104
|
-
def main(
|
|
120
|
+
def main():
|
|
121
|
+
parser = get_parser()
|
|
122
|
+
args = parser.parse_args()
|
|
123
|
+
|
|
105
124
|
root_path = Path(args.root)
|
|
106
125
|
output_dir = Path(args.output_dir)
|
|
107
126
|
metadata_dir = Path(args.metadata_dir)
|
|
108
127
|
mimic_dir = Path(args.mimic_dir) if args.mimic_dir is not None else None
|
|
109
128
|
|
|
110
|
-
|
|
111
|
-
|
|
129
|
+
num_workers = max(args.workers, 1)
|
|
130
|
+
if args.debug:
|
|
131
|
+
num_workers = 1
|
|
112
132
|
else:
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
f"
|
|
118
|
-
"directory, please run the script with --rebase."
|
|
133
|
+
cpu_count = multiprocessing.cpu_count()
|
|
134
|
+
if num_workers > cpu_count:
|
|
135
|
+
logger.warning(
|
|
136
|
+
f"Number of workers (--workers) is greater than the number of available CPUs "
|
|
137
|
+
f"({cpu_count}). Setting the number of workers to {cpu_count}."
|
|
119
138
|
)
|
|
120
|
-
|
|
139
|
+
num_workers = cpu_count
|
|
121
140
|
|
|
122
141
|
if root_path.is_dir():
|
|
123
142
|
data_paths = glob.glob(str(root_path / "**/*.csv"), recursive=True)
|
|
@@ -128,6 +147,34 @@ def main(args):
|
|
|
128
147
|
else:
|
|
129
148
|
data_paths = [root_path]
|
|
130
149
|
|
|
150
|
+
if not output_dir.exists():
|
|
151
|
+
output_dir.mkdir()
|
|
152
|
+
else:
|
|
153
|
+
if args.rebase:
|
|
154
|
+
shutil.rmtree(output_dir)
|
|
155
|
+
output_dir.mkdir()
|
|
156
|
+
elif output_dir.exists():
|
|
157
|
+
if args.skip_if_exists:
|
|
158
|
+
ls = glob.glob(str(output_dir / "**/*"), recursive=True)
|
|
159
|
+
expected_files = []
|
|
160
|
+
for subset in set(os.path.dirname(x) for x in data_paths):
|
|
161
|
+
expected_files.extend([
|
|
162
|
+
os.path.join(str(output_dir), os.path.basename(subset), f"{i}.h5")
|
|
163
|
+
for i in range(num_workers)
|
|
164
|
+
])
|
|
165
|
+
if set(expected_files).issubset(set(ls)):
|
|
166
|
+
logger.info(
|
|
167
|
+
f"Output directory already contains the expected files. Skipping the "
|
|
168
|
+
"processing as --skip-if-exists is set. If you want to rebase the directory, "
|
|
169
|
+
"please run the script with --rebase."
|
|
170
|
+
)
|
|
171
|
+
return
|
|
172
|
+
else:
|
|
173
|
+
raise ValueError(
|
|
174
|
+
f"File exists: '{str(output_dir.resolve())}'. If you want to rebase the "
|
|
175
|
+
"directory automatically, please run the script with --rebase."
|
|
176
|
+
)
|
|
177
|
+
|
|
131
178
|
label_col_name = args.cohort_label_name
|
|
132
179
|
|
|
133
180
|
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
|
|
@@ -295,7 +342,7 @@ def main(args):
|
|
|
295
342
|
codes_metadata,
|
|
296
343
|
output_dir,
|
|
297
344
|
output_name,
|
|
298
|
-
|
|
345
|
+
num_workers,
|
|
299
346
|
d_items,
|
|
300
347
|
d_labitems,
|
|
301
348
|
warned_codes,
|
|
@@ -303,27 +350,36 @@ def main(args):
|
|
|
303
350
|
)
|
|
304
351
|
|
|
305
352
|
# meds --> remed
|
|
306
|
-
|
|
307
|
-
if
|
|
353
|
+
logger.info(f"Start processing {data_path}")
|
|
354
|
+
if num_workers <= 1:
|
|
308
355
|
length_per_subject_gathered = [meds_to_remed_partial(data)]
|
|
309
356
|
del data
|
|
310
357
|
else:
|
|
311
358
|
subject_ids = data["subject_id"].unique().to_list()
|
|
312
|
-
n =
|
|
359
|
+
n = num_workers
|
|
313
360
|
subject_id_chunks = [subject_ids[i::n] for i in range(n)]
|
|
314
361
|
data_chunks = []
|
|
315
362
|
for subject_id_chunk in subject_id_chunks:
|
|
316
363
|
data_chunks.append(data.filter(pl.col("subject_id").is_in(subject_id_chunk)))
|
|
317
364
|
del data
|
|
318
|
-
|
|
365
|
+
|
|
366
|
+
num_valid_data_chunks = sum(map(lambda x: len(x) > 0, data_chunks))
|
|
367
|
+
if num_valid_data_chunks < num_workers:
|
|
368
|
+
raise ValueError(
|
|
369
|
+
"Number of valid data chunks (= number of unique subjects) were smaller "
|
|
370
|
+
"than the specified num workers (--workers) due to the small size of data. "
|
|
371
|
+
"Consider reducing the number of workers."
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
pool = multiprocessing.get_context("spawn").Pool(processes=num_workers)
|
|
319
375
|
# the order is preserved
|
|
320
376
|
length_per_subject_gathered = pool.map(meds_to_remed_partial, data_chunks)
|
|
321
377
|
pool.close()
|
|
322
378
|
pool.join()
|
|
323
379
|
del data_chunks
|
|
324
380
|
|
|
325
|
-
if len(length_per_subject_gathered) !=
|
|
326
|
-
|
|
381
|
+
if len(length_per_subject_gathered) != num_workers:
|
|
382
|
+
raise ValueError(
|
|
327
383
|
"Number of processed workers were smaller than the specified num workers "
|
|
328
384
|
"(--workers) due to the small size of data. Consider reducing the number of "
|
|
329
385
|
"workers."
|
|
@@ -393,7 +449,7 @@ def meds_to_remed(
|
|
|
393
449
|
|
|
394
450
|
if do_break and col_event not in warned_codes:
|
|
395
451
|
warned_codes.append(col_event)
|
|
396
|
-
|
|
452
|
+
logger.warning(
|
|
397
453
|
"The dataset contains some codes that are not specified in "
|
|
398
454
|
"the codes metadata, which may not be intended. Note that we "
|
|
399
455
|
f"process this code as it is for now: {col_event}."
|
|
@@ -579,6 +635,4 @@ def meds_to_remed(
|
|
|
579
635
|
|
|
580
636
|
|
|
581
637
|
if __name__ == "__main__":
|
|
582
|
-
|
|
583
|
-
args = parser.parse_args()
|
|
584
|
-
main(args)
|
|
638
|
+
main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: genhpf
|
|
3
|
-
Version: 1.0.2
|
|
3
|
+
Version: 1.0.4
|
|
4
4
|
Summary: GenHPF: General Healthcare Predictive Framework with Multi-task Multi-source Learning
|
|
5
5
|
Author-email: Jungwoo Oh <ojw0123@kaist.ac.kr>, Kyunghoon Hur <pacesun@kaist.ac.kr>
|
|
6
6
|
License: MIT license
|
|
@@ -18,6 +18,7 @@ Requires-Dist: h5pickle==0.4.2
|
|
|
18
18
|
Requires-Dist: scikit-learn==1.6.1
|
|
19
19
|
Requires-Dist: pandas==2.2.3
|
|
20
20
|
Requires-Dist: polars==1.17.1
|
|
21
|
+
Requires-Dist: pyarrow==17.0.0
|
|
21
22
|
Provides-Extra: dev
|
|
22
23
|
Requires-Dist: pre-commit; extra == "dev"
|
|
23
24
|
Requires-Dist: black; extra == "dev"
|
|
@@ -40,10 +40,10 @@ genhpf/scripts/test.py,sha256=wWi7OLqxsW9blj21m3RTirvziQ5UpkjkngOkgkE3Vb4,10149
|
|
|
40
40
|
genhpf/scripts/train.py,sha256=5f5PYOkiW7BahbFArvdOguAzUdDnY4Urw7Nx3aJ4kjs,12488
|
|
41
41
|
genhpf/scripts/preprocess/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
42
42
|
genhpf/scripts/preprocess/manifest.py,sha256=ZIK16e4vs_cS2K_tM1GaT38hc1nBHk6JB9Uga6OjgU4,2711
|
|
43
|
-
genhpf/scripts/preprocess/preprocess_meds.py,sha256=
|
|
43
|
+
genhpf/scripts/preprocess/preprocess_meds.py,sha256=x5R4KDzzB-21IUHjfkyo4-Be1t9U4oaurjm5VhxZ5Rw,25050
|
|
44
44
|
genhpf/scripts/preprocess/genhpf/README.md,sha256=qtpM_ABJk5yI8xbsUj1sZ71yX5bybx9ZvAymo0Lh5Vc,2877
|
|
45
45
|
genhpf/scripts/preprocess/genhpf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
|
-
genhpf/scripts/preprocess/genhpf/main.py,sha256=
|
|
46
|
+
genhpf/scripts/preprocess/genhpf/main.py,sha256=EF3sce0ltowMHIGK7zLEQEOnzOWQ_WJxoBowknHV3mQ,6161
|
|
47
47
|
genhpf/scripts/preprocess/genhpf/manifest.py,sha256=uHx0POSs9-ZB8Vtib7rPJ6hgDVJ1CBN6Ccfa4PpqmnM,2663
|
|
48
48
|
genhpf/scripts/preprocess/genhpf/sample_dataset.py,sha256=JzjMY2ynIYoWWtRlBG9Hxv6EoF27jJHyd3VYfqsM0Xs,5569
|
|
49
49
|
genhpf/scripts/preprocess/genhpf/ehrs/__init__.py,sha256=8bA4Pk0ylLIpwFQKEx6lis0k_inh4owF2SlHjHhKkeE,895
|
|
@@ -59,9 +59,9 @@ genhpf/utils/distributed_utils.py,sha256=000xKlw8SLoSH16o6n2bB3eueGR0aVD_DufPYES
|
|
|
59
59
|
genhpf/utils/file_io.py,sha256=hnZXdMtAibfFDoIfn-SDusl-v7ZImeUEh0eD2MIxbG4,4919
|
|
60
60
|
genhpf/utils/pdb.py,sha256=400rk1pVfOpVpzKIFHnTRlZ2VCtBqRh9G-pRRwu2Oqo,930
|
|
61
61
|
genhpf/utils/utils.py,sha256=BoC_7Gz8uCHbUBCpcXGBMD-5irApi_6xM7nU-2ac4aA,6176
|
|
62
|
-
genhpf-1.0.
|
|
63
|
-
genhpf-1.0.
|
|
64
|
-
genhpf-1.0.
|
|
65
|
-
genhpf-1.0.
|
|
66
|
-
genhpf-1.0.
|
|
67
|
-
genhpf-1.0.
|
|
62
|
+
genhpf-1.0.4.dist-info/LICENSE,sha256=VK_rvhY2Xi_DAIZHtauni5O9-1_do5SNWjrskv4amg8,1065
|
|
63
|
+
genhpf-1.0.4.dist-info/METADATA,sha256=Mgs4WysCKfBf4E2Jik2BgMdZPM8w1-rL5NoqgeqM5Zo,10589
|
|
64
|
+
genhpf-1.0.4.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
65
|
+
genhpf-1.0.4.dist-info/entry_points.txt,sha256=Wp94VV2w9KasBDLaluLM5EnjLgjNOAQVu44wKRDAwmQ,288
|
|
66
|
+
genhpf-1.0.4.dist-info/top_level.txt,sha256=lk846Vmnvydb6UZn8xmowj60nkrZYexNOGGnPM-IbhA,7
|
|
67
|
+
genhpf-1.0.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|