genhpf 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of genhpf might be problematic. Click here for more details.

@@ -129,6 +129,7 @@ def main():
129
129
  num_workers = max(args.workers, 1)
130
130
  if args.debug:
131
131
  num_workers = 1
132
+ os.environ["RAYON_RS_NUM_CPUS"] = "1"
132
133
  else:
133
134
  cpu_count = multiprocessing.cpu_count()
134
135
  if num_workers > cpu_count:
@@ -209,9 +210,16 @@ def main():
209
210
  data_path = Path(data_path)
210
211
  subdir = data_path.relative_to(root_path).parent
211
212
  if data_path.suffix == ".csv":
212
- data = pl.scan_csv(data_path)
213
+ data = pl.scan_csv(
214
+ data_path,
215
+ low_memory=True if args.debug else False,
216
+ )
213
217
  elif data_path.suffix == ".parquet":
214
- data = pl.scan_parquet(data_path)
218
+ data = pl.scan_parquet(
219
+ data_path,
220
+ parallel="none" if args.debug else "auto",
221
+ low_memory=True if args.debug else False,
222
+ )
215
223
  else:
216
224
  raise ValueError(f"Unsupported file format: {data_path.suffix}")
217
225
 
@@ -311,6 +319,9 @@ def main():
311
319
  pl.col("time").list.sample(n=pl.col("code").list.len(), with_replacement=True)
312
320
  )
313
321
 
322
+ if args.debug:
323
+ data = data[:5000]
324
+
314
325
  if str(subdir) != ".":
315
326
  output_name = str(subdir)
316
327
  else:
@@ -347,6 +358,7 @@ def main():
347
358
  d_labitems,
348
359
  warned_codes,
349
360
  max_event_length,
361
+ args.debug,
350
362
  )
351
363
 
352
364
  # meds --> remed
@@ -402,6 +414,7 @@ def meds_to_remed(
402
414
  d_labitems,
403
415
  warned_codes,
404
416
  max_event_length,
417
+ debug,
405
418
  df_chunk,
406
419
  ):
407
420
  code_matching_pattern = re.compile(r"\d+")
@@ -590,6 +603,14 @@ def meds_to_remed(
590
603
  maintain_order=True,
591
604
  ).agg(pl.all())
592
605
 
606
+ if debug:
607
+ df_chunk = df_chunk.with_columns(
608
+ [
609
+ pl.col("time").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(str))),
610
+ pl.col("data_index").map_elements(lambda x: x[-100:], return_dtype=pl.List(pl.List(int)))
611
+ ]
612
+ )
613
+
593
614
  df_chunk = df_chunk.sort(by=["subject_id", "cohort_end"])
594
615
  # regard {subject_id} as {cohort_id}: {subject_id}_{cohort_number}
595
616
  df_chunk = df_chunk.with_columns(pl.col("subject_id").cum_count().over("subject_id").alias("suffix"))
@@ -615,11 +636,15 @@ def meds_to_remed(
615
636
 
616
637
  sample_result = result.create_group(sample[0])
617
638
 
639
+ times = np.concatenate(sample[2])
618
640
  data_indices = np.concatenate(sample[3])
641
+ if debug:
642
+ data_indices = data_indices[-100:]
643
+ times = times[-100:]
644
+
619
645
  data = events_data[data_indices]
620
646
  sample_result.create_dataset("hi", data=data, dtype="i2", compression="lzf", shuffle=True)
621
647
 
622
- times = np.concatenate(sample[2])
623
648
  times = [datetime.strptime(x, "%Y-%m-%d %H:%M:%S") for x in times]
624
649
  times = np.cumsum(np.diff(times))
625
650
  times = list(map(lambda x: round(x.total_seconds() / 60), times))
genhpf/scripts/train.py CHANGED
@@ -36,6 +36,9 @@ def main(cfg: Config) -> None:
36
36
  # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
37
37
  logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
38
38
 
39
+ if cfg.common.debug:
40
+ os.environ["OMP_NUM_THREADS"] = "1"
41
+
39
42
  assert cfg.dataset.batch_size is not None, "batch_size must be specified"
40
43
  metrics.reset()
41
44
 
genhpf/trainer.py CHANGED
@@ -202,7 +202,7 @@ class Trainer(object):
202
202
  dataset,
203
203
  batch_size=self.cfg.dataset.batch_size,
204
204
  shuffle=True if not dist.is_initialized() else False,
205
- num_workers=self.cfg.dataset.num_workers,
205
+ num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
206
206
  collate_fn=dataset.collator,
207
207
  sampler=batch_sampler,
208
208
  )
@@ -220,7 +220,7 @@ class Trainer(object):
220
220
  dataset,
221
221
  batch_size=self.cfg.dataset.batch_size,
222
222
  shuffle=False,
223
- num_workers=self.cfg.dataset.num_workers,
223
+ num_workers=self.cfg.dataset.num_workers if not self.cfg.common.debug else 0,
224
224
  collate_fn=dataset.collator,
225
225
  sampler=batch_sampler,
226
226
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: genhpf
3
- Version: 1.0.6
3
+ Version: 1.0.8
4
4
  Summary: GenHPF: General Healthcare Predictive Framework with Multi-task Multi-source Learning
5
5
  Author-email: Jungwoo Oh <ojw0123@kaist.ac.kr>, Kyunghoon Hur <pacesun@kaist.ac.kr>
6
6
  License: MIT license
@@ -1,5 +1,5 @@
1
1
  genhpf/__init__.py,sha256=uh6oTFMxEX_AwRqlfDmNeS3kU4QhY-KXG6nsQ2kjWNo,219
2
- genhpf/trainer.py,sha256=TXertjuaRNPVxvgrwI1PTJBHvRFYgNVeaZ65R1tFPmI,13267
2
+ genhpf/trainer.py,sha256=v8wadlwI_HCopbCyEkaHw_abu2MscPibJjBWMg5pFw0,13339
3
3
  genhpf/configs/__init__.py,sha256=L0heECTJaH5SyESeCWxbnpjAnJAIh8z05M8--DlQI8k,393
4
4
  genhpf/configs/config.yaml,sha256=0Y8eL7b8lh3ZVSO8h7JhTPHi_CcPQ69zBv-2iTocjAg,63
5
5
  genhpf/configs/configs.py,sha256=WpO_EzUoM32sKVtiVV4ynKrMGSt1Crdjf1C0Sc9Rhfg,10723
@@ -37,10 +37,10 @@ genhpf/modules/layer_norm.py,sha256=-aVKThi1pWvVMbMAzyQG1co6MHPBCUZgxWJKYzIqsPQ,
37
37
  genhpf/modules/positional_encoding.py,sha256=Rf_qHdQArljEggRO4EHufc_JHq9-i44Oog1w9Bh51DQ,754
38
38
  genhpf/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
39
  genhpf/scripts/test.py,sha256=DZPiZa-Tm6kKLcK3R1EH82gjq4Hbl098IAY4kA3fQxg,10288
40
- genhpf/scripts/train.py,sha256=V4abCZ0r6qWxeJZdXyk4uXSVFXscqH_dhzk7CZuWrBA,12872
40
+ genhpf/scripts/train.py,sha256=-CY_OLRAX3wbthmH3fzkzSuZEEjHGKg0J4jzbbr9HoU,12942
41
41
  genhpf/scripts/preprocess/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
42
  genhpf/scripts/preprocess/manifest.py,sha256=ZIK16e4vs_cS2K_tM1GaT38hc1nBHk6JB9Uga6OjgU4,2711
43
- genhpf/scripts/preprocess/preprocess_meds.py,sha256=4GIK-_sQwB5A11FSbr_VnABY2MHxxNbFVmzSo71KpgQ,25074
43
+ genhpf/scripts/preprocess/preprocess_meds.py,sha256=mch8Zl9Ht28fx7nsYfuFb0sc_PN6l1kBQ5iCeEEcrFw,25856
44
44
  genhpf/scripts/preprocess/genhpf/README.md,sha256=qtpM_ABJk5yI8xbsUj1sZ71yX5bybx9ZvAymo0Lh5Vc,2877
45
45
  genhpf/scripts/preprocess/genhpf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  genhpf/scripts/preprocess/genhpf/main.py,sha256=EF3sce0ltowMHIGK7zLEQEOnzOWQ_WJxoBowknHV3mQ,6161
@@ -59,9 +59,9 @@ genhpf/utils/distributed_utils.py,sha256=000xKlw8SLoSH16o6n2bB3eueGR0aVD_DufPYES
59
59
  genhpf/utils/file_io.py,sha256=hnZXdMtAibfFDoIfn-SDusl-v7ZImeUEh0eD2MIxbG4,4919
60
60
  genhpf/utils/pdb.py,sha256=400rk1pVfOpVpzKIFHnTRlZ2VCtBqRh9G-pRRwu2Oqo,930
61
61
  genhpf/utils/utils.py,sha256=BoC_7Gz8uCHbUBCpcXGBMD-5irApi_6xM7nU-2ac4aA,6176
62
- genhpf-1.0.6.dist-info/LICENSE,sha256=VK_rvhY2Xi_DAIZHtauni5O9-1_do5SNWjrskv4amg8,1065
63
- genhpf-1.0.6.dist-info/METADATA,sha256=jNN97lqcfOLt3fbpmwd643IcTT-PlVof7IlkCur9zQs,10589
64
- genhpf-1.0.6.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
65
- genhpf-1.0.6.dist-info/entry_points.txt,sha256=Wp94VV2w9KasBDLaluLM5EnjLgjNOAQVu44wKRDAwmQ,288
66
- genhpf-1.0.6.dist-info/top_level.txt,sha256=lk846Vmnvydb6UZn8xmowj60nkrZYexNOGGnPM-IbhA,7
67
- genhpf-1.0.6.dist-info/RECORD,,
62
+ genhpf-1.0.8.dist-info/LICENSE,sha256=VK_rvhY2Xi_DAIZHtauni5O9-1_do5SNWjrskv4amg8,1065
63
+ genhpf-1.0.8.dist-info/METADATA,sha256=k5-iE6UYfJ0rx_NJTuHVM4uw5IdhuJvztoORAtpc_6Q,10589
64
+ genhpf-1.0.8.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
65
+ genhpf-1.0.8.dist-info/entry_points.txt,sha256=Wp94VV2w9KasBDLaluLM5EnjLgjNOAQVu44wKRDAwmQ,288
66
+ genhpf-1.0.8.dist-info/top_level.txt,sha256=lk846Vmnvydb6UZn8xmowj60nkrZYexNOGGnPM-IbhA,7
67
+ genhpf-1.0.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (76.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5