genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
@@ -0,0 +1,445 @@
1
+ """
2
+ Wrapper around various loggers and progress bars (e.g., tqdm).
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import sys
9
+ from collections import OrderedDict
10
+ from contextlib import contextmanager
11
+ from numbers import Number
12
+ from typing import Optional
13
+
14
+ import torch
15
+
16
+ from .meters import AverageMeter, StopwatchMeter, TimeMeter
17
+
18
+ from genhpf.utils import distributed_utils as dist_utils
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_entity: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
):
    """Wrap ``iterator`` in the progress bar selected by ``log_format``.

    Args:
        iterator: the iterable to wrap.
        log_format: one of ``"json"``, ``"none"``, ``"simple"``, ``"tqdm"``
            or ``"csv"``; ``None`` falls back to ``default_log_format``.
        log_interval: forwarded to the json / simple / csv bars.
        log_file: if given, a ``logging.FileHandler`` for this path is
            attached to this module's logger.
        epoch: forwarded to the bar (used for its prefix).
        prefix: forwarded to the bar (used for its prefix).
        default_log_format: format used when ``log_format`` is ``None``.
        wandb_project: if truthy, the bar is wrapped in
            ``WandBProgressBarWrapper``.
        wandb_entity: forwarded to ``WandBProgressBarWrapper``.
        wandb_run_name: forwarded to ``WandBProgressBarWrapper`` as
            ``run_name``.

    Returns:
        The constructed progress bar.

    Raises:
        ValueError: if the resolved ``log_format`` is not recognised.
    """
    if log_format is None:
        log_format = default_log_format

    if log_file is not None:
        # Same object as the module-level ``logger``.
        module_logger = logging.getLogger(__name__)
        log_path = os.path.abspath(log_file)
        # Attach at most one handler per file: this function may be called
        # repeatedly, and stacking handlers would duplicate every record
        # written to the file.
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
            for handler in module_logger.handlers
        )
        if not already_attached:
            module_logger.addHandler(logging.FileHandler(filename=log_file))

    # tqdm needs an interactive terminal; otherwise use plain log lines.
    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    elif log_format == "csv":
        bar = CsvProgressBar(iterator, epoch, prefix, log_interval)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, wandb_entity, run_name=wandb_run_name)

    return bar
61
+
62
+
63
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    Args:
        args: namespace providing ``log_format`` and ``log_interval``; an
            optional truthy ``no_progress_bar`` attribute switches the
            default format to the ``no_progress_bar`` argument.
        iterator: the iterable to wrap.
        epoch: forwarded to :func:`progress_bar`.
        prefix: forwarded to :func:`progress_bar`.
        default: default log format used when ``args.log_format`` is None.
        no_progress_bar: default log format used instead of ``default``
            when ``args.no_progress_bar`` is truthy.

    Returns:
        The progress bar built by :func:`progress_bar`.
    """
    if getattr(args, "no_progress_bar", False):
        default = no_progress_bar
    # Only pass keywords that progress_bar() actually accepts; an unknown
    # keyword would make every call fail with TypeError.
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        default_log_format=default,
    )
84
+
85
+
86
def format_stat(stat):
    """Convert a logged statistic into a printable / serialisable value.

    Numbers are rendered with ``g`` formatting, an ``AverageMeter`` by its
    ``avg`` with three decimals, a ``TimeMeter`` by its rounded ``avg``, a
    ``StopwatchMeter`` by its rounded ``sum``, and a tensor is converted
    with ``tolist()``. Anything else is returned unchanged.
    """
    if isinstance(stat, Number):
        return format(stat, "g")
    if isinstance(stat, AverageMeter):
        return format(stat.avg, ".3f")
    if isinstance(stat, TimeMeter):
        return format(round(stat.avg), "g")
    if isinstance(stat, StopwatchMeter):
        return format(round(stat.sum), "g")
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
98
+
99
+
100
class BaseProgressBar(object):
    """Abstract class for progress bars."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        # Start offset taken from the iterable's ``n`` attribute, if any.
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        # The prefix is "epoch NNN", the caller's prefix, or both joined
        # by " | ".
        parts = []
        if epoch is not None:
            parts.append("epoch {:03d}".format(epoch))
        if prefix is not None:
            parts.append(prefix)
        self.prefix = " | ".join(parts)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        pass

    def _str_commas(self, stats):
        """Render already-formatted stats as ``key=value, key=value``."""
        return ", ".join(name + "=" + text.strip() for name, text in stats.items())

    def _str_pipes(self, stats):
        """Render already-formatted stats as ``key value | key value``."""
        return " | ".join(name + " " + text.strip() for name, text in stats.items())

    def _format_stats(self, stats):
        """Return an OrderedDict mapping each stat name to its string form."""
        # Preprocess stats according to datatype
        return OrderedDict(
            (name, str(format_stat(value)))
            for name, value in OrderedDict(stats).items()
        )
149
+
150
+
151
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily set ``logger.name`` to ``new_name`` inside a ``with`` block.

    Args:
        logger: the logger whose ``name`` attribute is swapped.
        new_name: the temporary name; ``None`` leaves the name untouched.

    Yields:
        The same ``logger`` object.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Restore the name even when the body raises, so a failed log call
        # cannot leave the logger permanently renamed.
        logger.name = old_name
158
+
159
+
160
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Both are filled in once iteration starts.
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0

        due = step > 0 and self.log_interval is not None and step % self.log_interval == 0
        if not due:
            return

        # Fractional epoch progress, only meaningful when an epoch is known.
        if self.epoch is not None:
            update = self.epoch - 1 + (self.i + 1) / float(self.size)
        else:
            update = None
        payload = self._format_stats(stats, epoch=self.epoch, update=update)
        with rename_logger(logger, tag):
            logger.info(json.dumps(payload))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            # Namespace every key with the tag, e.g. "valid_loss".
            self.stats = OrderedDict(
                [(tag + "_" + name, value) for name, value in self.stats.items()]
            )
        payload = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(payload))

    def _format_stats(self, stats, epoch=None, update=None):
        """Build the ordered payload: epoch, update, then the public stats."""
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        # Preprocess stats according to datatype; keys starting with "_"
        # are left out of the payload.
        postfix.update(
            (key, format_stat(stats[key]))
            for key in stats.keys()
            if not key.startswith("_")
        )
        return postfix
211
+
212
class NoopProgressBar(BaseProgressBar):
    """No logging."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        # Pass every item through untouched.
        for item in self.iterable:
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        return None

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        return None
229
+
230
+
231
class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Both are filled in once iteration starts.
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        due = step > 0 and self.log_interval is not None and step % self.log_interval == 0
        if not due:
            return

        formatted = self._format_stats(stats)
        postfix = self._str_commas(formatted)
        message = "{}: {:5d} / {:d} {}".format(
            self.prefix, self.i + 1, self.size, postfix
        )
        with rename_logger(logger, tag):
            logger.info(message)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        postfix = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
264
+
265
+
266
class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        # Imported lazily so tqdm is only needed when this bar is used.
        from tqdm import tqdm

        # Disable the bar when the logger would not emit INFO records.
        quiet = logger.getEffectiveLevel() > logging.INFO
        self.tqdm = tqdm(iterable, self.prefix, leave=False, disable=quiet)

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        formatted = self._format_stats(stats)
        self.tqdm.set_postfix(formatted, refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        postfix = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
292
+
293
+ try:
294
+ import csv
295
+ except ImportError:
296
+ csv = None
297
+
298
class CsvProgressBar(BaseProgressBar):
    """Log to csv."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Both are filled in once iteration starts.
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to csv."""
        self._log_to_csv(stats, tag, step)
        self._print_log(stats, tag, step)

    def print(self, stats, tag=None, step=None):
        """Log end-of-epoch stats."""
        self._log_to_csv(stats, tag, step)
        self._print_log(stats, tag, step)

    def _print_log(self, stats, tag=None, step=None):
        """Print intermediate stats."""
        step = step or self.i or 0

        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            stats = self._format_stats(stats)
            postfix = self._str_commas(stats)
            with rename_logger(logger, tag):
                logger.info(
                    "{}: {:5d} / {:d} {}".format(
                        self.prefix, self.i + 1, self.size, postfix
                    )
                )

    def _log_to_csv(self, stats, tag=None, step=None):
        """Append one row of numeric stats to ``<tag>.csv`` (or ``log.csv``).

        Only ``AverageMeter`` (its ``val``) and plain ``Number`` stats are
        written. When a stat name is not yet a column, the file is rewritten
        with the widened header and earlier rows padded with empty cells.
        Nothing is written when ``csv`` is unavailable or when this process
        has a data-parallel rank above 0.
        """
        if csv is None:
            return
        if dist_utils.get_data_parallel_world_size() > 1 and dist_utils.get_data_parallel_rank() > 0:
            return

        csv_logs = {}

        if step is None:
            csv_logs["step"] = stats["num_updates"] if "num_updates" in stats else None
        else:
            csv_logs["step"] = step

        # Sort the names so the column order does not depend on set
        # iteration order.
        for key in sorted(stats.keys() - {"num_updates"}):
            if isinstance(stats[key], AverageMeter):
                csv_logs[key] = stats[key].val
            elif isinstance(stats[key], Number):
                csv_logs[key] = stats[key]

        fname = "log.csv" if tag is None else tag + ".csv"

        # The csv module expects files opened with newline=""; otherwise
        # extra blank lines can appear on platforms that translate "\n".
        if not os.path.exists(fname):
            with open(fname, "w", newline="") as f:
                csv.writer(f).writerow(list(csv_logs))

        with open(fname, "r", newline="") as f:
            # An existing but empty file simply has no header yet.
            headers = next(csv.reader(f), [])

        missing = [key for key in csv_logs if key not in headers]
        if missing:
            headers.extend(missing)
            with open(fname, "r", newline="") as f:
                rd = csv.reader(f)
                # drop headers
                next(rd, None)
                lines = [headers] + [line + [""] * len(missing) for line in rd]
            with open(fname, "w", newline="") as f:
                csv.writer(f).writerows(lines)

        with open(fname, "a", newline="") as f:
            csv.DictWriter(f, fieldnames=headers).writerow(csv_logs)
392
+
393
+ try:
394
+ import wandb
395
+ except ImportError:
396
+ wandb = None
397
+
398
class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases."""

    def __init__(self, wrapped_bar, wandb_project, wandb_entity, run_name=None):
        self.wrapped_bar = wrapped_bar
        if wandb is None:
            logger.warning("wandb not found, pip install wandb")
            return

        # reinit=False to ensure if wandb.init() is called multiple times
        # within one process it still references the same run
        wandb.init(project=wandb_project, entity=wandb_entity, reinit=False, name=run_name)

    def __len__(self):
        # The base-class __len__ reads ``self.iterable``, which this wrapper
        # never sets; delegate to the wrapped bar instead.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to wandb."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        """Send the numeric stats to wandb, keyed as ``<tag>/<name>``.

        Only ``AverageMeter`` (its ``val``) and plain ``Number`` stats are
        sent. ``num_updates`` is used as the step when ``step`` is None and
        is never logged as a metric. Does nothing when wandb is unavailable.
        """
        if wandb is None:
            return
        if step is None:
            step = stats["num_updates"] if "num_updates" in stats else None

        prefix = "" if tag is None else tag + "/"

        wandb_logs = {}
        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                wandb_logs[prefix + key] = stats[key].val
            elif isinstance(stats[key], Number):
                wandb_logs[prefix + key] = stats[key]

        wandb.log(wandb_logs, step=step)
@@ -0,0 +1,73 @@
1
+ import importlib
2
+ import os
3
+
4
+ from hydra.core.config_store import ConfigStore
5
+
6
+ from .genhpf import GenHPF #noqa
7
+ from genhpf.configs.utils import merge_with_parent
8
+
9
# name -> model class, filled in by register_model()
MODEL_REGISTRY = dict()
# name -> config dataclass, filled in by register_model() when one is given
MODEL_DATACLASS_REGISTRY = dict()
11
+
12
def build_model(cfg):
    """Instantiate the registered model selected by ``cfg._name``.

    Args:
        cfg: model configuration; its ``_name`` attribute selects the entry
            in ``MODEL_REGISTRY``.

    Returns:
        Whatever the registered class's ``build_model(cfg)`` returns.

    Raises:
        AssertionError: if ``cfg._name`` is missing or not registered.
    """
    model_type = getattr(cfg, "_name", None)
    model = MODEL_REGISTRY.get(model_type)

    # Raised explicitly (instead of via ``assert``) so the check is still
    # performed when Python runs with -O.
    if model is None:
        raise AssertionError(
            f"Could not infer model type from {str(model_type)}. "
            + "Available models: "
            + str(MODEL_REGISTRY.keys())
            + " Requested model type: "
            + str(model_type)
        )

    # set defaults from dataclass; a model may have been registered without
    # one, in which case the config is used as given
    dc = MODEL_DATACLASS_REGISTRY.get(model_type)
    if dc is not None:
        cfg = merge_with_parent(dc(), cfg)

    return model.build_model(cfg)
32
+
33
def register_model(name, dataclass=None):
    """
    Class decorator that adds a new model type to the registry.

    Args:
        name (str): the name of the model
        dataclass: optional config dataclass; when given it is recorded in
            ``MODEL_DATACLASS_REGISTRY`` and an instance is stored in the
            hydra ``ConfigStore`` under group ``"model"``
    """

    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        if not issubclass(cls, GenHPF):
            raise ValueError(f"Model ({name}: {cls.__name__}) must extend GenHPF")

        MODEL_REGISTRY[name] = cls
        if dataclass is None:
            return cls

        MODEL_DATACLASS_REGISTRY[name] = dataclass

        # Publish a default config node carrying the model name.
        config_store = ConfigStore.instance()
        default_node = dataclass()
        default_node._name = name
        config_store.store(name=name, group="model", node=default_node, provider="genhpf")

        return cls

    return register_model_cls
59
+
60
def import_models(models_dir, namespace):
    """Import every module or sub-package found directly in ``models_dir``.

    Entries whose name starts with ``_`` or ``.`` are skipped, as are
    entries that are neither ``.py`` files nor directories. Each remaining
    entry is imported as ``<namespace>.<name>``.
    """
    for entry in os.listdir(models_dir):
        if entry.startswith("_") or entry.startswith("."):
            continue
        is_module = entry.endswith(".py")
        if not is_module and not os.path.isdir(os.path.join(models_dir, entry)):
            continue
        # For a .py file, keep the text before the first ".py".
        model_name = entry.partition(".py")[0] if is_module else entry
        importlib.import_module(namespace + "." + model_name)
70
+
71
# Import every Python file and sub-package that sits next to this module
# (the models/ directory) under the "genhpf.models" namespace.
models_dir = os.path.split(__file__)[0]
import_models(models_dir, "genhpf.models")