genhpf 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of genhpf might be problematic. Click here for more details.
- genhpf/scripts/preprocess/preprocess_meds.py +28 -3
- genhpf/scripts/train.py +3 -0
- genhpf/trainer.py +2 -2
- {genhpf-1.0.6.dist-info → genhpf-1.0.8.dist-info}/METADATA +1 -1
- {genhpf-1.0.6.dist-info → genhpf-1.0.8.dist-info}/RECORD +9 -9
- {genhpf-1.0.6.dist-info → genhpf-1.0.8.dist-info}/WHEEL +1 -1
- {genhpf-1.0.6.dist-info → genhpf-1.0.8.dist-info}/LICENSE +0 -0
- {genhpf-1.0.6.dist-info → genhpf-1.0.8.dist-info}/entry_points.txt +0 -0
- {genhpf-1.0.6.dist-info → genhpf-1.0.8.dist-info}/top_level.txt +0 -0
|
@@ -129,6 +129,7 @@ def main():
|
|
|
129
129
|
num_workers = max(args.workers, 1)
|
|
130
130
|
if args.debug:
|
|
131
131
|
num_workers = 1
|
|
132
|
+
os.environ["RAYON_RS_NUM_CPUS"] = "1"
|
|
132
133
|
else:
|
|
133
134
|
cpu_count = multiprocessing.cpu_count()
|
|
134
135
|
if num_workers > cpu_count:
|
|
@@ -209,9 +210,16 @@ def main():
|
|
|
209
210
|
data_path = Path(data_path)
|
|
210
211
|
subdir = data_path.relative_to(root_path).parent
|
|
211
212
|
if data_path.suffix == ".csv":
|
|
212
|
-
data = pl.scan_csv(
|
|
213
|
+
data = pl.scan_csv(
|
|
214
|
+
data_path,
|
|
215
|
+
low_memory=True if args.debug else False,
|
|
216
|
+
)
|
|
213
217
|
elif data_path.suffix == ".parquet":
|
|
214
|
-
data = pl.scan_parquet(
|
|
218
|
+
data = pl.scan_parquet(
|
|
219
|
+
data_path,
|
|
220
|
+
parallel="none" if args.debug else "auto",
|
|
221
|
+
low_memory=True if args.debug else False,
|
|
222
|
+
)
|
|
215
223
|
else:
|
|
216
224
|
raise ValueError(f"Unsupported file format: {data_path.suffix}")
|
|
217
225
|
|
|
@@ -311,6 +319,9 @@ def main():
|
|
|
311
319
|
pl.col("time").list.sample(n=pl.col("code").list.len(), with_replacement=True)
|
|
312
320
|
)
|
|
313
321
|
|
|
322
|
+
if args.debug:
|
|
323
|
+
data = data[:5000]
|
|
324
|
+
|
|
314
325
|
if str(subdir) != ".":
|
|
315
326
|
output_name = str(subdir)
|
|
316
327
|
else:
|
|
@@ -347,6 +358,7 @@ def main():
|
|
|
347
358
|
d_labitems,
|
|
348
359
|
warned_codes,
|
|
349
360
|
max_event_length,
|
|
361
|
+
args.debug,
|
|
350
362
|
)
|
|
351
363
|
|
|
352
364
|
# meds --> remed
|
|
@@ -402,6 +414,7 @@ def meds_to_remed(
|
|
|
402
414
|
d_labitems,
|
|
403
415
|
warned_codes,
|
|
404
416
|
max_event_length,
|
|
417
|
+
debug,
|
|
405
418
|
df_chunk,
|
|
406
419
|
):
|
|
407
420
|
code_matching_pattern = re.compile(r"\d+")
|
|
@@ -590,6 +603,14 @@ def meds_to_remed(
|
|
|
590
603
|
maintain_order=True,
|
|
591
604
|
).agg(pl.all())
|
|
592
605
|
|
|
606
|
+
if debug:
|
|
607
|
+
df_chunk = df_chunk.with_columns(
|
|
608
|
+
[
|
|
609
|
+
pl.col("time").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(str))),
|
|
610
|
+
pl.col("data_index").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(int)))
|
|
611
|
+
]
|
|
612
|
+
)
|
|
613
|
+
|
|
593
614
|
df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
|
|
594
615
|
# regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
|
|
595
616
|
df_chunk = df_chunk.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
|
|
@@ -615,11 +636,15 @@ def meds_to_remed(
|
|
|
615
636
|
|
|
616
637
|
sample_result = result.create_group(sample[0])
|
|
617
638
|
|
|
639
|
+
times = np.concatenate(sample[2])
|
|
618
640
|
data_indices = np.concatenate(sample[3])
|
|
641
|
+
if debug:
|
|
642
|
+
data_indices = data_indices[-100:]
|
|
643
|
+
times = times[-100:]
|
|
644
|
+
|
|
619
645
|
data = events_data[data_indices]
|
|
620
646
|
sample_result.create_dataset("hi", data=data, dtype="i2", compression="lzf", shuffle=True)
|
|
621
647
|
|
|
622
|
-
times = np.concatenate(sample[2])
|
|
623
648
|
times = [datetime.strptime(x, "%Y-%m-%d %H:%M:%S") for x in times]
|
|
624
649
|
times = np.cumsum(np.diff(times))
|
|
625
650
|
times = list(map(lambda x: round(x.total_seconds() / 60), times))
|
genhpf/scripts/train.py
CHANGED
|
@@ -36,6 +36,9 @@ def main(cfg: Config) -> None:
|
|
|
36
36
|
# make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
|
|
37
37
|
logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
|
|
38
38
|
|
|
39
|
+
if cfg.common.debug:
|
|
40
|
+
os.environ["OMP_NUM_THREADS"] = "1"
|
|
41
|
+
|
|
39
42
|
assert cfg.dataset.batch_size is not None, "batch_size must be specified"
|
|
40
43
|
metrics.reset()
|
|
41
44
|
|
genhpf/trainer.py
CHANGED
|
@@ -202,7 +202,7 @@ class Trainer(object):
|
|
|
202
202
|
dataset,
|
|
203
203
|
batch_size=self.cfg.dataset.batch_size,
|
|
204
204
|
shuffle=True if not dist.is_initialized() else False,
|
|
205
|
-
num_workers=self.cfg.dataset.num_workers,
|
|
205
|
+
num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
|
|
206
206
|
collate_fn=dataset.collator,
|
|
207
207
|
sampler=batch_sampler,
|
|
208
208
|
)
|
|
@@ -220,7 +220,7 @@ class Trainer(object):
|
|
|
220
220
|
dataset,
|
|
221
221
|
batch_size=self.cfg.dataset.batch_size,
|
|
222
222
|
shuffle=False,
|
|
223
|
-
num_workers=self.cfg.dataset.num_workers,
|
|
223
|
+
num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
|
|
224
224
|
collate_fn=dataset.collator,
|
|
225
225
|
sampler=batch_sampler,
|
|
226
226
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: genhpf
|
|
3
|
-
Version: 1.0.6
|
|
3
|
+
Version: 1.0.8
|
|
4
4
|
Summary: GenHPF: General Healthcare Predictive Framework with Multi-task Multi-source Learning
|
|
5
5
|
Author-email: Jungwoo Oh <ojw0123@kaist.ac.kr>, Kyunghoon Hur <pacesun@kaist.ac.kr>
|
|
6
6
|
License: MIT license
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
genhpf/__init__.py,sha256=uh6oTFMxEX_AwRqlfDmNeS3kU4QhY-KXG6nsQ2kjWNo,219
|
|
2
|
-
genhpf/trainer.py,sha256=
|
|
2
|
+
genhpf/trainer.py,sha256=v8wadlwI_HCopbCyEkaHw_abu2MscPibJjBWMg5pFw0,13339
|
|
3
3
|
genhpf/configs/__init__.py,sha256=L0heECTJaH5SyESeCWxbnpjAnJAIh8z05M8--DlQI8k,393
|
|
4
4
|
genhpf/configs/config.yaml,sha256=0Y8eL7b8lh3ZVSO8h7JhTPHi_CcPQ69zBv-2iTocjAg,63
|
|
5
5
|
genhpf/configs/configs.py,sha256=WpO_EzUoM32sKVtiVV4ynKrMGSt1Crdjf1C0Sc9Rhfg,10723
|
|
@@ -37,10 +37,10 @@ genhpf/modules/layer_norm.py,sha256=-aVKThi1pWvVMbMAzyQG1co6MHPBCUZgxWJKYzIqsPQ,
|
|
|
37
37
|
genhpf/modules/positional_encoding.py,sha256=Rf_qHdQArljEggRO4EHufc_JHq9-i44Oog1w9Bh51DQ,754
|
|
38
38
|
genhpf/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
39
|
genhpf/scripts/test.py,sha256=DZPiZa-Tm6kKLcK3R1EH82gjq4Hbl098IAY4kA3fQxg,10288
|
|
40
|
-
genhpf/scripts/train.py,sha256
|
|
40
|
+
genhpf/scripts/train.py,sha256=-CY_OLRAX3wbthmH3fzkzSuZEEjHGKg0J4jzbbr9HoU,12942
|
|
41
41
|
genhpf/scripts/preprocess/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
42
42
|
genhpf/scripts/preprocess/manifest.py,sha256=ZIK16e4vs_cS2K_tM1GaT38hc1nBHk6JB9Uga6OjgU4,2711
|
|
43
|
-
genhpf/scripts/preprocess/preprocess_meds.py,sha256=
|
|
43
|
+
genhpf/scripts/preprocess/preprocess_meds.py,sha256=mch8Zl9Ht28fx7nsYfuFb0sc_PN6l1kBQ5iCeEEcrFw,25856
|
|
44
44
|
genhpf/scripts/preprocess/genhpf/README.md,sha256=qtpM_ABJk5yI8xbsUj1sZ71yX5bybx9ZvAymo0Lh5Vc,2877
|
|
45
45
|
genhpf/scripts/preprocess/genhpf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
46
|
genhpf/scripts/preprocess/genhpf/main.py,sha256=EF3sce0ltowMHIGK7zLEQEOnzOWQ_WJxoBowknHV3mQ,6161
|
|
@@ -59,9 +59,9 @@ genhpf/utils/distributed_utils.py,sha256=000xKlw8SLoSH16o6n2bB3eueGR0aVD_DufPYES
|
|
|
59
59
|
genhpf/utils/file_io.py,sha256=hnZXdMtAibfFDoIfn-SDusl-v7ZImeUEh0eD2MIxbG4,4919
|
|
60
60
|
genhpf/utils/pdb.py,sha256=400rk1pVfOpVpzKIFHnTRlZ2VCtBqRh9G-pRRwu2Oqo,930
|
|
61
61
|
genhpf/utils/utils.py,sha256=BoC_7Gz8uCHbUBCpcXGBMD-5irApi_6xM7nU-2ac4aA,6176
|
|
62
|
-
genhpf-1.0.
|
|
63
|
-
genhpf-1.0.
|
|
64
|
-
genhpf-1.0.
|
|
65
|
-
genhpf-1.0.
|
|
66
|
-
genhpf-1.0.
|
|
67
|
-
genhpf-1.0.
|
|
62
|
+
genhpf-1.0.8.dist-info/LICENSE,sha256=VK_rvhY2Xi_DAIZHtauni5O9-1_do5SNWjrskv4amg8,1065
|
|
63
|
+
genhpf-1.0.8.dist-info/METADATA,sha256=k5-iE6UYfJ0rx_NJTuHVM4uw5IdhuJvztoORAtpc_6Q,10589
|
|
64
|
+
genhpf-1.0.8.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
|
65
|
+
genhpf-1.0.8.dist-info/entry_points.txt,sha256=Wp94VV2w9KasBDLaluLM5EnjLgjNOAQVu44wKRDAwmQ,288
|
|
66
|
+
genhpf-1.0.8.dist-info/top_level.txt,sha256=lk846Vmnvydb6UZn8xmowj60nkrZYexNOGGnPM-IbhA,7
|
|
67
|
+
genhpf-1.0.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|