genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
@@ -0,0 +1,374 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import bisect
7
+ import time
8
+ from collections import OrderedDict
9
+ from typing import Dict, Optional
10
+
11
+ import torch
12
+ import numpy as np
13
+ from sklearn.metrics import average_precision_score, roc_auc_score
14
+
15
+ def warn(*args, **kwargs):
16
+ pass
17
+ import warnings
18
+ warnings.warn = warn
19
+
20
+ def type_as(a, b):
21
+ if torch.is_tensor(a) and torch.is_tensor(b):
22
+ return a.to(b)
23
+ else:
24
+ return a
25
+
26
+ class Meter(object):
27
+ """Base class for Meters."""
28
+
29
+ def __init__(self):
30
+ pass
31
+
32
+ def state_dict(self):
33
+ return {}
34
+
35
+ def load_state_dict(self, state_dict):
36
+ pass
37
+
38
+ def reset(self):
39
+ raise NotImplementedError
40
+
41
+ @property
42
+ def smoothed_value(self) -> float:
43
+ """Smoothed value used for logging."""
44
+ raise NotImplementedError
45
+
46
+ def safe_round(number, ndigits):
47
+ if hasattr(number, "__round__"):
48
+ return round(number, ndigits)
49
+ elif torch.is_tensor(number) and number.numel() == 1:
50
+ return safe_round(number.item(), ndigits)
51
+ elif np.ndim(number) == 0 and hasattr(number, "item"):
52
+ return safe_round(number.item(), ndigits)
53
+ else:
54
+ return number
55
+
56
+ class SumMeter(Meter):
57
+ """Computes and stores the sum"""
58
+
59
+ def __init__(self, round: Optional[int] = None):
60
+ self.round = round
61
+ self.reset()
62
+
63
+ def reset(self):
64
+ self.sum = 0 # sum from all updates
65
+
66
+ def update(self, val):
67
+ if val is not None:
68
+ self.sum = type_as(self.sum, val) + val
69
+
70
+ def state_dict(self):
71
+ return {
72
+ "sum": self.sum,
73
+ "round": self.round,
74
+ }
75
+
76
+ def load_state_dict(self, state_dict):
77
+ self.sum = state_dict["sum"]
78
+ self.round = state_dict.get("round", None)
79
+
80
+ @property
81
+ def smoothed_value(self) -> float:
82
+ val = self.sum
83
+ if self.round is not None and val is not None:
84
+ val = safe_round(val, self.round)
85
+ return val
86
+
87
+ class AverageMeter(Meter):
88
+ """Computes and stores the average and current value"""
89
+
90
+ def __init__(self, round: Optional[int] = None):
91
+ self.round = round
92
+ self.reset()
93
+
94
+ def reset(self):
95
+ self.val = None # most recent update
96
+ self.sum = 0 # sum from all updates
97
+ self.count = 0 # total n from all updates
98
+
99
+ def update(self, val, n = 1):
100
+ if val is not None:
101
+ self.val = val
102
+ if n > 0:
103
+ self.sum = type_as(self.sum, val) + (val * n)
104
+ self.count = type_as(self.count, n) + n
105
+
106
+ def state_dict(self):
107
+ return {
108
+ "val" : self.val,
109
+ "sum" : self.sum,
110
+ "count" : self.count,
111
+ "round" : self.round
112
+ }
113
+
114
+ def load_state_dict(self, state_dict):
115
+ self.val = state_dict["val"]
116
+ self.sum = state_dict["sum"]
117
+ self.count = state_dict["count"]
118
+ self.round = state_dict.get("round", None)
119
+
120
+ @property
121
+ def avg(self):
122
+ return self.sum / self.count if self.count > 0 else self.val
123
+
124
+ @property
125
+ def smoothed_value(self) -> float:
126
+ val = self.avg
127
+ if self.round is not None and val is not None:
128
+ val = safe_round(val, self.round)
129
+ return val
130
+
131
+ class TimeMeter(Meter):
132
+ """Compute the average occurrence of some event per second"""
133
+
134
+ def __init__(
135
+ self,
136
+ init: int = 0,
137
+ n: int = 0,
138
+ round: Optional[int] = None
139
+ ):
140
+ self.round = round
141
+ self.reset(init, n)
142
+
143
+ def reset(self, init = 0, n = 0):
144
+ self.init = init
145
+ self.start = time.perf_counter()
146
+ self.n = n
147
+ self.i = 0
148
+
149
+ def update(self, val = 1):
150
+ self.n = type_as(self.n, val) + val
151
+ self.i += 1
152
+
153
+ def state_dict(self):
154
+ return {
155
+ "init": self.elapsed_time,
156
+ "n": self.n,
157
+ "round": self.round
158
+ }
159
+
160
+ def load_state_dict(self, state_dict):
161
+ if "start" in state_dict:
162
+ # backwards compatibility for old state_dicts
163
+ self.reset(init = state_dict["init"])
164
+ else:
165
+ self.reset(init = state_dict["init"], n = state_dict["n"])
166
+ self.round = state_dict.get("round", None)
167
+
168
+ @property
169
+ def avg(self):
170
+ return self.n / self.elapsed_time
171
+
172
+ @property
173
+ def elapsed_time(self):
174
+ return self.init + (time.perf_counter() - self.start)
175
+
176
+ @property
177
+ def smoothed_value(self) -> float:
178
+ val = self.avg
179
+ if self.round is not None and val is not None:
180
+ val = safe_round(val, self.round)
181
+ return val
182
+
183
+
184
+ class AUCMeter(Meter):
185
+ "Stores scores / targets to compute AUROC and AUPRC"
186
+
187
+ def __init__(self,):
188
+ self.reset()
189
+
190
+ def reset(self):
191
+ self.scores = []
192
+ self.targets = []
193
+
194
+ def update(self, prob, target):
195
+ if torch.is_tensor(prob):
196
+ prob = prob.cpu().numpy()
197
+ if torch.is_tensor(target):
198
+ target = target.cpu().numpy()
199
+
200
+ self.scores.append(prob)
201
+ self.targets.append(target)
202
+
203
+ def state_dict(self):
204
+ return {
205
+ "scores": self.scores,
206
+ "targets": self.targets,
207
+ }
208
+
209
+ def load_state_dict(self, state_dict):
210
+ self.scores = state_dict["scores"]
211
+ self.targets = state_dict["targets"]
212
+ self.round = state_dict.get("round", None)
213
+
214
+ @property
215
+ def auroc(self):
216
+ y_true = np.concatenate(self.targets)
217
+ y_score = np.concatenate(self.scores)
218
+ # if y_true.shape != y_score.shape:
219
+ # y_true = np.eye(y_score.shape[1])[y_true]
220
+ if y_true.shape[0] > 127 and len(y_true.shape) >1:
221
+ mask = (y_true.sum(axis=0)!=0)
222
+ y_true = y_true[:, mask]
223
+ y_score = y_score[:, mask]
224
+ try:
225
+ return roc_auc_score(y_true=y_true, y_score=y_score, average='macro')
226
+ except ValueError:
227
+ return float("nan")
228
+
229
+ @property
230
+ def auprc(self):
231
+ y_true = np.concatenate(self.targets)
232
+ y_score = np.concatenate(self.scores)
233
+ if y_true.shape != y_score.shape:
234
+ y_true = np.eye(y_score.shape[1])[y_true]
235
+ try:
236
+ return average_precision_score(y_true=y_true, y_score=y_score, average='micro')
237
+ except ValueError:
238
+ return float("nan")
239
+
240
+ @property
241
+ def smoothed_value(self) -> float:
242
+ raise AttributeError(
243
+ "AUC meter cannot have smoothed values. Please "
244
+ "make sure the key of this meter starts with '_'."
245
+ )
246
+
247
+
248
+ class StopwatchMeter(Meter):
249
+ """Computes the sum/avg duration of some event in seconds"""
250
+
251
+ def __init__(self, round: Optional[int] = None):
252
+ self.round = round
253
+ self.sum = 0
254
+ self.n = 0
255
+ self.start_time = None
256
+
257
+ def start(self):
258
+ self.start_time = time.perf_counter()
259
+
260
+ def stop(self, n = 1, prehook = None):
261
+ if self.start_time is not None:
262
+ if prehook is not None:
263
+ prehook()
264
+ delta = time.perf_counter() - self.start_time
265
+ self.sum = self.sum + delta
266
+ self.n = type_as(self.n, n) + n
267
+
268
+ def reset(self):
269
+ self.sum = 0 # cumulative time during which stopwatch was active
270
+ self.n = 0 # total n across all start/stop
271
+ self.start()
272
+
273
+ def state_dict(self):
274
+ return {
275
+ "sum": self.sum,
276
+ "n": self.n,
277
+ "round": self.round
278
+ }
279
+
280
+ def load_state_dict(self, state_dict):
281
+ self.sum = state_dict["sum"]
282
+ self.n = state_dict["n"]
283
+ self.start_time = None
284
+ self.round = state_dict.get("round", None)
285
+
286
+ @property
287
+ def avg(self):
288
+ return self.sum / self.n if self.n > 0 else self.sum
289
+
290
+ @property
291
+ def elapsed_time(self):
292
+ if self.start_time is None:
293
+ return 0.0
294
+ return time.perf_counter() - self.start_time
295
+
296
+ @property
297
+ def smoothed_value(self) -> float:
298
+ val = self.avg if self.sum > 0 else self.elapsed_time
299
+ if self.round is not None and val is not None:
300
+ val = safe_round(val, self.round)
301
+ return val
302
+
303
+ class MetersDict(OrderedDict):
304
+ """A sorted dictionary of :class:`Meters`.
305
+
306
+ Meters are sorted according to a priority that is given when the
307
+ meter is first added to the dictionary.
308
+ """
309
+
310
+ def __init__(self, *args, **kwargs):
311
+ super().__init__(*args, **kwargs)
312
+ self.priorities = []
313
+
314
+ def __setitem__(self, key, value):
315
+ assert key not in self, "MetersDict doesn't support reassignment"
316
+ priority, value = value
317
+ bisect.insort(self.priorities, (priority, len(self.priorities), key))
318
+ super().__setitem__(key, value)
319
+ for _, _, key in self.priorities: # reorder dict to match priorities
320
+ self.move_to_end(key)
321
+
322
+ def add_meter(self, key, meter, priority):
323
+ self.__setitem__(key, (priority, meter))
324
+
325
+ def state_dict(self):
326
+ return [
327
+ (pri, key, self[key].__class__.__name__, self[key].state_dict())
328
+ for pri, _, key in self.priorities
329
+ # can't serialize DerivedMeter instances
330
+ if not isinstance(self[key], MetersDict._DerivedMeter)
331
+ ]
332
+
333
+ def load_state_dict(self, state_dict):
334
+ self.clear()
335
+ self.priorities.clear()
336
+ for pri, key, meter_cls, meter_state in state_dict:
337
+ meter = globals()[meter_cls]()
338
+ meter.load_state_dict(meter_state)
339
+ self.add_meter(key, meter, pri)
340
+
341
+ def get_smoothed_value(self, key: str) -> float:
342
+ """Get a single smoothed value."""
343
+ meter = self[key]
344
+ if isinstance(meter, MetersDict._DerivedMeter):
345
+ # print("hello: ", key, meter.fn(self))
346
+ return meter.fn(self)
347
+ else:
348
+ return meter.smoothed_value
349
+
350
+ def get_smoothed_values(self) -> Dict[str, float]:
351
+ """Get all smoothed values."""
352
+ return OrderedDict(
353
+ [
354
+ (key, self.get_smoothed_value(key))
355
+ for key in self.keys()
356
+ if not key.startswith("_")
357
+ ]
358
+ )
359
+
360
+ def reset(self):
361
+ """Reset Meter instances."""
362
+ for meter in self.values():
363
+ if isinstance(meter, MetersDict._DerivedMeter):
364
+ continue
365
+ meter.reset()
366
+
367
+ class _DerivedMeter(Meter):
368
+ """A Meter whose values are derived from other Meters."""
369
+
370
+ def __init__(self, fn):
371
+ self.fn = fn
372
+
373
+ def reset(self):
374
+ pass
@@ -0,0 +1,155 @@
1
+ import contextlib
2
+ import uuid
3
+ from collections import defaultdict, OrderedDict
4
+ from typing import Callable, List, Optional
5
+
6
+ from .meters import *
7
+
8
+ _aggregators = OrderedDict()
9
+ _active_aggregators = OrderedDict()
10
+ _active_aggregators_cnt = defaultdict(lambda: 0)
11
+
12
+ def reset() -> None:
13
+ _aggregators.clear()
14
+ _active_aggregators.clear()
15
+ _active_aggregators_cnt.clear()
16
+
17
+ _aggregators['default'] = MetersDict()
18
+ _active_aggregators['default'] = _aggregators['default']
19
+ _active_aggregators_cnt['default'] = 1
20
+
21
+ reset()
22
+
23
+ @contextlib.contextmanager
24
+ def aggregate(name=None, new_root=False):
25
+ if name is None:
26
+ name = str(uuid.uuid4())
27
+ assert name not in _aggregators
28
+ agg = MetersDict()
29
+ else:
30
+ assert name != 'default'
31
+ agg = _aggregators.setdefault(name, MetersDict())
32
+
33
+ if new_root:
34
+ backup_aggregators = _active_aggregators.copy()
35
+ _active_aggregators.clear()
36
+ backup_aggregators_cnt = _active_aggregators_cnt.copy()
37
+ _active_aggregators_cnt.clear()
38
+
39
+ _active_aggregators[name] = agg
40
+ _active_aggregators_cnt[name] += 1
41
+
42
+ yield agg
43
+
44
+ _active_aggregators_cnt[name] -= 1
45
+ if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
46
+ del _active_aggregators[name]
47
+
48
+ if new_root:
49
+ _active_aggregators.clear()
50
+ _active_aggregators.update(backup_aggregators)
51
+ _active_aggregators_cnt.clear()
52
+ _active_aggregators_cnt.update(backup_aggregators_cnt)
53
+
54
+ def get_active_aggregators() -> List[MetersDict]:
55
+ return list(_active_aggregators.values())
56
+
57
+ def log_scalar(
58
+ key,
59
+ value,
60
+ weight=1,
61
+ priority=10,
62
+ round=None
63
+ ):
64
+ for agg in get_active_aggregators():
65
+ if key not in agg:
66
+ agg.add_meter(key, AverageMeter(round=round), priority)
67
+ agg[key].update(value, weight)
68
+
69
+ def log_scalar_sum(
70
+ key,
71
+ value,
72
+ priority=10,
73
+ round=None
74
+ ):
75
+ for agg in get_active_aggregators():
76
+ if key not in agg:
77
+ agg.add_meter(key, SumMeter(round=round), priority)
78
+ agg[key].update(value)
79
+
80
+ def log_derived(
81
+ key,
82
+ fn: Callable[[MetersDict], float],
83
+ priority=20
84
+ ):
85
+ for agg in get_active_aggregators():
86
+ if key not in agg:
87
+ agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)
88
+
89
+ def log_speed(
90
+ key,
91
+ value,
92
+ priority=30,
93
+ round=None
94
+ ):
95
+ for agg in get_active_aggregators():
96
+ if key not in agg:
97
+ agg.add_meter(key, TimeMeter(round=round), priority)
98
+ agg[key].reset()
99
+ else:
100
+ agg[key].update(value)
101
+
102
+ def log_start_time(
103
+ key,
104
+ priority=40,
105
+ round=None
106
+ ):
107
+ for agg in get_active_aggregators():
108
+ if key not in agg:
109
+ agg.add_meter(key, StopwatchMeter(round=round), priority)
110
+ agg[key].start()
111
+
112
+ def log_stop_time(
113
+ key,
114
+ weight=0.0,
115
+ prehook=None
116
+ ):
117
+ for agg in get_active_aggregators():
118
+ if key in agg:
119
+ agg[key].stop(weight, prehook)
120
+
121
+ def log_custom(
122
+ new_meter_fn: Callable[[], Meter],
123
+ key,
124
+ *args,
125
+ priority=50,
126
+ **kwargs,
127
+ ):
128
+ for agg in get_active_aggregators():
129
+ if key not in agg:
130
+ agg.add_meter(key, new_meter_fn(), priority)
131
+ agg[key].update(*args, **kwargs)
132
+
133
+ def reset_meter(name, key) -> None:
134
+ meter = get_meter(name, key)
135
+ if meter is not None:
136
+ meter.reset()
137
+
138
+ def reset_meters(name) -> None:
139
+ meters = get_meters(name)
140
+ if meters is not None:
141
+ meters.reset()
142
+
143
+ def get_meter(name, key) -> Meter:
144
+ if name not in _aggregators:
145
+ return None
146
+ return _aggregators[name].get(key, None)
147
+
148
+ def get_meters(name) -> MetersDict:
149
+ return _aggregators.get(name, None)
150
+
151
+ def get_smoothed_values(name) -> Dict[str, float]:
152
+ return _aggregators[name].get_smoothed_values()
153
+
154
+ def state_dict():
155
+ return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])