genhpf-1.0.7-py3-none-any.whl → genhpf-1.0.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of genhpf might be problematic. Click here for more details.
- genhpf/models/genhpf.py +11 -0
- genhpf/scripts/preprocess/preprocess_meds.py +27 -3
- genhpf/scripts/train.py +7 -0
- genhpf/trainer.py +2 -2
- {genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/METADATA +1 -1
- {genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/RECORD +10 -10
- {genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/LICENSE +0 -0
- {genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/WHEEL +0 -0
- {genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/entry_points.txt +0 -0
- {genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/top_level.txt +0 -0
genhpf/models/genhpf.py
CHANGED
|
@@ -73,6 +73,7 @@ class GenHPFConfig(BaseConfig):
|
|
|
73
73
|
)
|
|
74
74
|
|
|
75
75
|
vocab_size: int = II("dataset.vocab_size")
|
|
76
|
+
debug: bool = II("common.debug")
|
|
76
77
|
|
|
77
78
|
|
|
78
79
|
class GenHPF(nn.Module):
|
|
@@ -80,6 +81,16 @@ class GenHPF(nn.Module):
|
|
|
80
81
|
super().__init__()
|
|
81
82
|
self.cfg = cfg
|
|
82
83
|
|
|
84
|
+
if cfg.debug:
|
|
85
|
+
cfg.encoder_layers = 1
|
|
86
|
+
cfg.encoder_embed_dim = 32
|
|
87
|
+
cfg.encoder_ffn_embed_dim = 128
|
|
88
|
+
cfg.encoder_attention_heads = 2
|
|
89
|
+
cfg.agg_layers = 1
|
|
90
|
+
cfg.agg_embed_dim = 32
|
|
91
|
+
cfg.agg_ffn_embed_dim = 128
|
|
92
|
+
cfg.agg_attention_heads = 2
|
|
93
|
+
|
|
83
94
|
self.structure = cfg.structure
|
|
84
95
|
assert self.structure in GENHPF_MODEL_ARCH_CHOICES
|
|
85
96
|
|
|
genhpf/scripts/preprocess/preprocess_meds.py
CHANGED
|
@@ -210,9 +210,16 @@ def main():
|
|
|
210
210
|
data_path = Path(data_path)
|
|
211
211
|
subdir = data_path.relative_to(root_path).parent
|
|
212
212
|
if data_path.suffix == ".csv":
|
|
213
|
-
data = pl.scan_csv(
|
|
213
|
+
data = pl.scan_csv(
|
|
214
|
+
data_path,
|
|
215
|
+
low_memory=True if args.debug else False,
|
|
216
|
+
)
|
|
214
217
|
elif data_path.suffix == ".parquet":
|
|
215
|
-
data = pl.scan_parquet(
|
|
218
|
+
data = pl.scan_parquet(
|
|
219
|
+
data_path,
|
|
220
|
+
parallel="none" if args.debug else "auto",
|
|
221
|
+
low_memory=True if args.debug else False,
|
|
222
|
+
)
|
|
216
223
|
else:
|
|
217
224
|
raise ValueError(f"Unsupported file format: {data_path.suffix}")
|
|
218
225
|
|
|
@@ -312,6 +319,9 @@ def main():
|
|
|
312
319
|
pl.col("time").list.sample(n=pl.col("code").list.len(), with_replacement=True)
|
|
313
320
|
)
|
|
314
321
|
|
|
322
|
+
if args.debug:
|
|
323
|
+
data = data[:5000]
|
|
324
|
+
|
|
315
325
|
if str(subdir) != ".":
|
|
316
326
|
output_name = str(subdir)
|
|
317
327
|
else:
|
|
@@ -348,6 +358,7 @@ def main():
|
|
|
348
358
|
d_labitems,
|
|
349
359
|
warned_codes,
|
|
350
360
|
max_event_length,
|
|
361
|
+
args.debug,
|
|
351
362
|
)
|
|
352
363
|
|
|
353
364
|
# meds --> remed
|
|
@@ -403,6 +414,7 @@ def meds_to_remed(
|
|
|
403
414
|
d_labitems,
|
|
404
415
|
warned_codes,
|
|
405
416
|
max_event_length,
|
|
417
|
+
debug,
|
|
406
418
|
df_chunk,
|
|
407
419
|
):
|
|
408
420
|
code_matching_pattern = re.compile(r"\d+")
|
|
@@ -591,6 +603,14 @@ def meds_to_remed(
|
|
|
591
603
|
maintain_order=True,
|
|
592
604
|
).agg(pl.all())
|
|
593
605
|
|
|
606
|
+
if debug:
|
|
607
|
+
df_chunk = df_chunk.with_columns(
|
|
608
|
+
[
|
|
609
|
+
pl.col("time").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(str))),
|
|
610
|
+
pl.col("data_index").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(int)))
|
|
611
|
+
]
|
|
612
|
+
)
|
|
613
|
+
|
|
594
614
|
df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
|
|
595
615
|
# regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
|
|
596
616
|
df_chunk = df_chunk.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
|
|
@@ -616,11 +636,15 @@ def meds_to_remed(
|
|
|
616
636
|
|
|
617
637
|
sample_result = result.create_group(sample[0])
|
|
618
638
|
|
|
639
|
+
times = np.concatenate(sample[2])
|
|
619
640
|
data_indices = np.concatenate(sample[3])
|
|
641
|
+
if debug:
|
|
642
|
+
data_indices = data_indices[-100:]
|
|
643
|
+
times = times[-100:]
|
|
644
|
+
|
|
620
645
|
data = events_data[data_indices]
|
|
621
646
|
sample_result.create_dataset("hi", data=data, dtype="i2", compression="lzf", shuffle=True)
|
|
622
647
|
|
|
623
|
-
times = np.concatenate(sample[2])
|
|
624
648
|
times = [datetime.strptime(x, "%Y-%m-%d %H:%M:%S") for x in times]
|
|
625
649
|
times = np.cumsum(np.diff(times))
|
|
626
650
|
times = list(map(lambda x: round(x.total_seconds() / 60), times))
|
genhpf/scripts/train.py
CHANGED
|
@@ -36,6 +36,13 @@ def main(cfg: Config) -> None:
|
|
|
36
36
|
# make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
|
|
37
37
|
logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
|
|
38
38
|
|
|
39
|
+
if cfg.common.debug:
|
|
40
|
+
os.environ["OMP_NUM_THREADS"] = "4"
|
|
41
|
+
os.environ["MKL_NUM_THREADS"] = "4"
|
|
42
|
+
torch.set_num_threads(4)
|
|
43
|
+
torch.set_num_interop_threads(4)
|
|
44
|
+
cfg.optimization.max_epoch = 1
|
|
45
|
+
|
|
39
46
|
assert cfg.dataset.batch_size is not None, "batch_size must be specified"
|
|
40
47
|
metrics.reset()
|
|
41
48
|
|
genhpf/trainer.py
CHANGED
|
@@ -202,7 +202,7 @@ class Trainer(object):
|
|
|
202
202
|
dataset,
|
|
203
203
|
batch_size=self.cfg.dataset.batch_size,
|
|
204
204
|
shuffle=True if not dist.is_initialized() else False,
|
|
205
|
-
num_workers=self.cfg.dataset.num_workers,
|
|
205
|
+
num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
|
|
206
206
|
collate_fn=dataset.collator,
|
|
207
207
|
sampler=batch_sampler,
|
|
208
208
|
)
|
|
@@ -220,7 +220,7 @@ class Trainer(object):
|
|
|
220
220
|
dataset,
|
|
221
221
|
batch_size=self.cfg.dataset.batch_size,
|
|
222
222
|
shuffle=False,
|
|
223
|
-
num_workers=self.cfg.dataset.num_workers,
|
|
223
|
+
num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
|
|
224
224
|
collate_fn=dataset.collator,
|
|
225
225
|
sampler=batch_sampler,
|
|
226
226
|
)
|
|
{genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: genhpf
|
|
3
|
-
Version: 1.0.7
|
|
3
|
+
Version: 1.0.9
|
|
4
4
|
Summary: GenHPF: General Healthcare Predictive Framework with Multi-task Multi-source Learning
|
|
5
5
|
Author-email: Jungwoo Oh <ojw0123@kaist.ac.kr>, Kyunghoon Hur <pacesun@kaist.ac.kr>
|
|
6
6
|
License: MIT license
|
|
{genhpf-1.0.7.dist-info → genhpf-1.0.9.dist-info}/RECORD
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
genhpf/__init__.py,sha256=uh6oTFMxEX_AwRqlfDmNeS3kU4QhY-KXG6nsQ2kjWNo,219
|
|
2
|
-
genhpf/trainer.py,sha256=
|
|
2
|
+
genhpf/trainer.py,sha256=v8wadlwI_HCopbCyEkaHw_abu2MscPibJjBWMg5pFw0,13339
|
|
3
3
|
genhpf/configs/__init__.py,sha256=L0heECTJaH5SyESeCWxbnpjAnJAIh8z05M8--DlQI8k,393
|
|
4
4
|
genhpf/configs/config.yaml,sha256=0Y8eL7b8lh3ZVSO8h7JhTPHi_CcPQ69zBv-2iTocjAg,63
|
|
5
5
|
genhpf/configs/configs.py,sha256=WpO_EzUoM32sKVtiVV4ynKrMGSt1Crdjf1C0Sc9Rhfg,10723
|
|
@@ -23,7 +23,7 @@ genhpf/loggings/meters.py,sha256=ECdJTwFHx_4D22iNbv9VRxlh9iibX8aU9QeHPkqNmXQ,107
|
|
|
23
23
|
genhpf/loggings/metrics.py,sha256=3CSBA5C3bd-G-zNer7BeOqSZj-tn6twbpLqAlt-FQ_A,3935
|
|
24
24
|
genhpf/loggings/progress_bar.py,sha256=9-24WAFDsp6WSS-JncnQtQMwo7DnNEYakAt7a8pkhF0,14140
|
|
25
25
|
genhpf/models/__init__.py,sha256=EG4YnL8Uiem8iUNm72euHJlim0IZj3inzFVFCFOvPCE,2223
|
|
26
|
-
genhpf/models/genhpf.py,sha256=
|
|
26
|
+
genhpf/models/genhpf.py,sha256=Y9f8H3fgUm1H-QWTnRzcQMu1Pkl6i0ZNNRuSmZZ6Zh0,9712
|
|
27
27
|
genhpf/models/genhpf_mlm.py,sha256=rExPpm1HDjljAjgFbYx2bgS6VSaIKF6-P7VJcq6YLB0,1882
|
|
28
28
|
genhpf/models/genhpf_predictor.py,sha256=i-XIh7S3ozpB_r4JZI27sfdnbANyQYpBIOrDDgsiWvc,2163
|
|
29
29
|
genhpf/models/genhpf_simclr.py,sha256=Iuqx0fy0AQurkTk0e5hEv12eJyeGGGiQJiRKXGgOTnI,1629
|
|
@@ -37,10 +37,10 @@ genhpf/modules/layer_norm.py,sha256=-aVKThi1pWvVMbMAzyQG1co6MHPBCUZgxWJKYzIqsPQ,
|
|
|
37
37
|
genhpf/modules/positional_encoding.py,sha256=Rf_qHdQArljEggRO4EHufc_JHq9-i44Oog1w9Bh51DQ,754
|
|
38
38
|
genhpf/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
39
|
genhpf/scripts/test.py,sha256=DZPiZa-Tm6kKLcK3R1EH82gjq4Hbl098IAY4kA3fQxg,10288
|
|
40
|
-
genhpf/scripts/train.py,sha256=
|
|
40
|
+
genhpf/scripts/train.py,sha256=juUgfSVLAXhtBPzIEG09W5lkLlKIv2GHIbMn7IgBJjc,13099
|
|
41
41
|
genhpf/scripts/preprocess/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
42
42
|
genhpf/scripts/preprocess/manifest.py,sha256=ZIK16e4vs_cS2K_tM1GaT38hc1nBHk6JB9Uga6OjgU4,2711
|
|
43
|
-
genhpf/scripts/preprocess/preprocess_meds.py,sha256=
|
|
43
|
+
genhpf/scripts/preprocess/preprocess_meds.py,sha256=mch8Zl9Ht28fx7nsYfuFb0sc_PN6l1kBQ5iCeEEcrFw,25856
|
|
44
44
|
genhpf/scripts/preprocess/genhpf/README.md,sha256=qtpM_ABJk5yI8xbsUj1sZ71yX5bybx9ZvAymo0Lh5Vc,2877
|
|
45
45
|
genhpf/scripts/preprocess/genhpf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
46
|
genhpf/scripts/preprocess/genhpf/main.py,sha256=EF3sce0ltowMHIGK7zLEQEOnzOWQ_WJxoBowknHV3mQ,6161
|
|
@@ -59,9 +59,9 @@ genhpf/utils/distributed_utils.py,sha256=000xKlw8SLoSH16o6n2bB3eueGR0aVD_DufPYES
|
|
|
59
59
|
genhpf/utils/file_io.py,sha256=hnZXdMtAibfFDoIfn-SDusl-v7ZImeUEh0eD2MIxbG4,4919
|
|
60
60
|
genhpf/utils/pdb.py,sha256=400rk1pVfOpVpzKIFHnTRlZ2VCtBqRh9G-pRRwu2Oqo,930
|
|
61
61
|
genhpf/utils/utils.py,sha256=BoC_7Gz8uCHbUBCpcXGBMD-5irApi_6xM7nU-2ac4aA,6176
|
|
62
|
-
genhpf-1.0.
|
|
63
|
-
genhpf-1.0.
|
|
64
|
-
genhpf-1.0.
|
|
65
|
-
genhpf-1.0.
|
|
66
|
-
genhpf-1.0.
|
|
67
|
-
genhpf-1.0.
|
|
62
|
+
genhpf-1.0.9.dist-info/LICENSE,sha256=VK_rvhY2Xi_DAIZHtauni5O9-1_do5SNWjrskv4amg8,1065
|
|
63
|
+
genhpf-1.0.9.dist-info/METADATA,sha256=0YRTk9CjFLdEVayQOm7mvdDUi1oBVTLv-v-GANBbuaY,10589
|
|
64
|
+
genhpf-1.0.9.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
|
65
|
+
genhpf-1.0.9.dist-info/entry_points.txt,sha256=Wp94VV2w9KasBDLaluLM5EnjLgjNOAQVu44wKRDAwmQ,288
|
|
66
|
+
genhpf-1.0.9.dist-info/top_level.txt,sha256=lk846Vmnvydb6UZn8xmowj60nkrZYexNOGGnPM-IbhA,7
|
|
67
|
+
genhpf-1.0.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|