genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Wrapper around various loggers and progress bars (e.g., tqdm).
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
from numbers import Number
|
|
12
|
+
from typing import Optional
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
from .meters import AverageMeter, StopwatchMeter, TimeMeter
|
|
17
|
+
|
|
18
|
+
from genhpf.utils import distributed_utils as dist_utils
|
|
19
|
+
|
|
20
|
+
# Module-level logger. The progress bars below write their text output through it,
# rename_logger() temporarily renames it, and progress_bar() may attach a FileHandler to it.
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_entity: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
):
    """Build a progress bar around ``iterator`` for the requested log format.

    Args:
        iterator: iterable to wrap.
        log_format: one of "json", "none", "simple", "tqdm", "csv". ``None`` selects
            ``default_log_format``. "tqdm" falls back to "simple" when stderr is not a TTY.
        log_interval: step interval for intermediate output (json/simple/csv bars).
        log_file: if given, a ``logging.FileHandler`` for this path is attached to the
            module logger. At most one handler is attached per path.
        epoch: current epoch, shown in the bar's prefix.
        prefix: extra prefix text.
        default_log_format: format used when ``log_format`` is None.
        wandb_project, wandb_entity, wandb_run_name: when ``wandb_project`` is set, the
            bar is wrapped in ``WandBProgressBarWrapper``.

    Returns:
        The constructed progress bar.

    Raises:
        ValueError: if ``log_format`` is not a known format.
    """
    if log_format is None:
        log_format = default_log_format
    if log_file is not None:
        # If progress_bar() is called repeatedly with the same log_file, attaching a
        # fresh FileHandler each time would duplicate every log line and leak handles.
        # FileHandler stores the absolute path in ``baseFilename``.
        log_path = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not already_attached:
            logger.addHandler(logging.FileHandler(filename=log_file))

    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    elif log_format == "csv":
        bar = CsvProgressBar(iterator, epoch, prefix, log_interval)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, wandb_entity, run_name=wandb_run_name)

    return bar
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    Reads ``args.log_format`` and ``args.log_interval``. If ``args.no_progress_bar``
    is truthy, ``no_progress_bar`` replaces ``default`` as the fallback format.
    """
    if getattr(args, "no_progress_bar", False):
        default = no_progress_bar
    # progress_bar() has no ``tensorboard_logdir`` parameter, so none is forwarded.
    # Passing one would raise TypeError on every call.
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        default_log_format=default,
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def format_stat(stat):
    """Convert one statistic into a value suitable for logging.

    - Plain numbers are rendered with ``{:g}``.
    - An ``AverageMeter`` is shown as its average with three decimals.
    - A ``TimeMeter`` is shown as its rounded average.
    - A ``StopwatchMeter`` is shown as its rounded sum.
    - A torch tensor becomes a nested Python list.
    - Anything else is returned as-is.
    """
    # The type checks are order-sensitive, so test them one at a time and return early.
    if isinstance(stat, Number):
        return "{:g}".format(stat)
    if isinstance(stat, AverageMeter):
        return "{:.3f}".format(stat.avg)
    if isinstance(stat, TimeMeter):
        return "{:g}".format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return "{:g}".format(round(stat.sum))
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class BaseProgressBar(object):
    """Common base for the progress-bar implementations.

    Holds the wrapped iterable and a start offset ``n``, taken from the iterable's ``n``
    attribute when present and 0 otherwise. It also builds a display prefix by joining
    "epoch NNN" and the caller's prefix with " | ". Subclasses provide ``__iter__``,
    ``log`` and ``print``.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        pieces = []
        if epoch is not None:
            pieces.append("epoch {:03d}".format(epoch))
        if prefix is not None:
            pieces.append(prefix)
        self.prefix = " | ".join(pieces)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Never suppress exceptions raised inside the ``with`` block.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration (no-op in the base class)."""
        return None

    def _str_commas(self, stats):
        """Render ``{"a": " 1 "}`` as ``"a=1"``, entries joined by ", "."""
        return ", ".join(name + "=" + text.strip() for name, text in stats.items())

    def _str_pipes(self, stats):
        """Render ``{"a": " 1 "}`` as ``"a 1"``, entries joined by " | "."""
        return " | ".join(name + " " + text.strip() for name, text in stats.items())

    def _format_stats(self, stats):
        """Return an ordered copy of ``stats`` with every value passed through format_stat and str()."""
        ordered = OrderedDict(stats)
        return OrderedDict(
            (name, str(format_stat(value))) for name, value in ordered.items()
        )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily rename ``logger`` for the duration of a ``with`` block.

    Args:
        logger: the logger object whose ``name`` attribute is swapped.
        new_name: name to use inside the block. ``None`` leaves the name unchanged.

    Yields:
        The same ``logger`` object.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Restore even if the body raises, so a failed log call does not leave the
        # logger permanently renamed.
        logger.name = old_name
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format.

    Every ``log_interval`` steps, ``log`` emits one ``json.dumps`` string through the
    module logger. ``print`` emits the end-of-epoch stats the same way, with keys
    prefixed by ``<tag>_`` when a tag is given.
    """

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None  # index of the most recently yielded item (starts at self.n)
        self.size = None  # len(iterable), recorded when iteration begins

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0

        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            # Fractional epoch progress. It is only computable once iteration has
            # started and the iterable is non-empty. Otherwise the "update" field is
            # omitted instead of raising TypeError / ZeroDivisionError.
            update = None
            if self.epoch is not None and self.i is not None and self.size:
                update = self.epoch - 1 + (self.i + 1) / float(self.size)
            stats = self._format_stats(stats, epoch=self.epoch, update=update)
            with rename_logger(logger, tag):
                logger.info(json.dumps(stats))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            self.stats = OrderedDict(
                [(tag + "_" + k, v) for k, v in self.stats.items()]
            )
        stats = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(stats))

    def _format_stats(self, stats, epoch=None, update=None):
        """Build an ordered dict of epoch/update markers followed by formatted stats.

        Keys starting with "_" are treated as private and skipped.
        """
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for key in stats.keys():
            if not key.startswith("_"):
                postfix[key] = format_stat(stats[key])
        return postfix
|
|
211
|
+
|
|
212
|
+
class NoopProgressBar(BaseProgressBar):
    """Progress bar that passes items through unchanged and discards all logging.

    The constructor is inherited unchanged from BaseProgressBar
    (``iterable, epoch=None, prefix=None``).
    """

    def __iter__(self):
        for item in self.iterable:
            yield item

    def log(self, stats, tag=None, step=None):
        """Discard intermediate stats."""
        return None

    def print(self, stats, tag=None, step=None):
        """Discard end-of-epoch stats."""
        return None
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class SimpleProgressBar(BaseProgressBar):
    """Plain-text progress logger intended for non-TTY environments.

    Every ``log_interval`` steps it writes ``"<prefix>: <i+1> / <size> k=v, ..."``
    through the module logger. ``print`` writes ``"<prefix> | k v | ..."``.
    """

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None  # index of the most recently yielded item
        self.size = None  # len(iterable), captured when iteration begins

    def __iter__(self):
        self.size = len(self.iterable)
        for position, item in enumerate(self.iterable, start=self.n):
            self.i = position
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        current = step or self.i or 0
        due = (
            current > 0
            and self.log_interval is not None
            and current % self.log_interval == 0
        )
        if not due:
            return
        rendered = self._str_commas(self._format_stats(stats))
        with rename_logger(logger, tag):
            line = "{}: {:5d} / {:d} {}".format(self.prefix, self.i + 1, self.size, rendered)
            logger.info(line)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        summary = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, summary))
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class TqdmProgressBar(BaseProgressBar):
    """Progress bar rendered by tqdm.

    Intermediate stats are shown as the bar's postfix. End-of-epoch stats are written
    through the module logger. The bar is disabled when the module logger's effective
    level is above INFO.
    """

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        from tqdm import tqdm

        quiet = logger.getEffectiveLevel() > logging.INFO
        self.tqdm = tqdm(iterable, desc=self.prefix, leave=False, disable=quiet)

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Show intermediate stats in the tqdm postfix (without forcing a redraw)."""
        formatted = self._format_stats(stats)
        self.tqdm.set_postfix(formatted, refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        summary = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, summary))
|
|
292
|
+
|
|
293
|
+
try:
|
|
294
|
+
import csv
|
|
295
|
+
except ImportError:
|
|
296
|
+
csv = None
|
|
297
|
+
|
|
298
|
+
class CsvProgressBar(BaseProgressBar):
    """Log to csv.

    Every ``log``/``print`` call appends one row to ``log.csv`` (or ``<tag>.csv``)
    relative to the current working directory. Only rank 0 writes when running with
    data parallelism. It also writes a human-readable line through the module logger
    every ``log_interval`` steps.
    """

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None  # index of the most recently yielded item
        self.size = None  # len(iterable), captured when iteration begins

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to csv."""
        self._log_to_csv(stats, tag, step)
        self._print_log(stats, tag, step)

    def print(self, stats, tag=None, step=None):
        """Log end-of-epoch stats."""
        self._log_to_csv(stats, tag, step)
        self._print_log(stats, tag, step)

    def _print_log(self, stats, tag=None, step=None):
        """Print intermediate stats."""
        step = step or self.i or 0

        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            stats = self._format_stats(stats)
            postfix = self._str_commas(stats)
            with rename_logger(logger, tag):
                logger.info(
                    "{}: {:5d} / {:d} {}".format(
                        self.prefix, self.i + 1, self.size, postfix
                    )
                )

    def _log_to_csv(self, stats, tag=None, step=None):
        """Append one row of numeric stats to the CSV file, widening its header if needed.

        The "step" column comes from ``step`` or, failing that, ``stats["num_updates"]``.
        ``AverageMeter`` values are reduced to ``.val``. Other non-numeric stats are
        dropped.
        """
        if csv is None:
            return
        if dist_utils.get_data_parallel_world_size() > 1 and dist_utils.get_data_parallel_rank() > 0:
            return

        csv_logs = {}
        if step is None:
            csv_logs["step"] = stats["num_updates"] if "num_updates" in stats else None
        else:
            csv_logs["step"] = step

        # Iterate in the stats' own order (not a set difference) so the column order
        # is deterministic across runs.
        for key, value in stats.items():
            if key == "num_updates":
                continue
            if isinstance(value, AverageMeter):
                csv_logs[key] = value.val
            elif isinstance(value, Number):
                csv_logs[key] = value

        fname = "log.csv" if tag is None else tag + ".csv"

        # The csv module requires files to be opened with newline="" (otherwise
        # rows get extra blank lines on platforms that translate "\n").
        headers = None
        if os.path.exists(fname):
            with open(fname, "r", newline="") as f:
                # An existing but empty file yields no header row instead of raising
                # StopIteration.
                headers = next(csv.reader(f), None)

        if not headers:
            # Missing or empty file: start it with a header row, "step" first.
            headers = ["step"] + [key for key in csv_logs if key != "step"]
            with open(fname, "w", newline="") as f:
                csv.writer(f).writerow(headers)
        else:
            new_cols = [key for key in csv_logs if key not in headers]
            if new_cols:
                # Rewrite the file with the widened header, padding existing rows.
                with open(fname, "r", newline="") as f:
                    rows = list(csv.reader(f))[1:]
                headers = headers + new_cols
                with open(fname, "w", newline="") as f:
                    wr = csv.writer(f)
                    wr.writerow(headers)
                    wr.writerows(row + [""] * len(new_cols) for row in rows)

        with open(fname, "a", newline="") as f:
            wr = csv.DictWriter(f, fieldnames=headers)
            wr.writerow(csv_logs)
|
|
392
|
+
|
|
393
|
+
try:
|
|
394
|
+
import wandb
|
|
395
|
+
except ImportError:
|
|
396
|
+
wandb = None
|
|
397
|
+
|
|
398
|
+
class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases while delegating iteration and console output.

    Every ``log``/``print`` call sends numeric stats to ``wandb.log`` and then forwards
    the call to the wrapped bar. ``AverageMeter`` values are reduced to ``.val``, and
    keys are prefixed with ``<tag>/`` when a tag is given. If ``wandb`` is not
    installed, a warning is logged and only the wrapped bar is used.
    """

    def __init__(self, wrapped_bar, wandb_project, wandb_entity, run_name=None):
        # BaseProgressBar.__init__ is intentionally not called; iteration, length and
        # console output are delegated to ``wrapped_bar``.
        self.wrapped_bar = wrapped_bar
        if wandb is None:
            logger.warning("wandb not found, pip install wandb")
            return

        # reinit=False to ensure if wandb.init() is called multiple times
        # within one process it still references the same run
        wandb.init(project=wandb_project, entity=wandb_entity, reinit=False, name=run_name)

    def __len__(self):
        # The inherited __len__ reads self.iterable, which this wrapper never sets,
        # so len(bar) used to raise AttributeError. Delegate to the wrapped bar.
        return len(self.wrapped_bar)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to wandb."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        """Send numeric stats to wandb. ``step`` defaults to ``stats["num_updates"]``."""
        if wandb is None:
            return
        if step is None:
            step = stats["num_updates"] if "num_updates" in stats else None

        prefix = "" if tag is None else tag + "/"

        wandb_logs = {}
        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                wandb_logs[prefix + key] = stats[key].val
            elif isinstance(stats[key], Number):
                wandb_logs[prefix + key] = stats[key]

        wandb.log(wandb_logs, step=step)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from hydra.core.config_store import ConfigStore
|
|
5
|
+
|
|
6
|
+
from .genhpf import GenHPF #noqa
|
|
7
|
+
from genhpf.configs.utils import merge_with_parent
|
|
8
|
+
|
|
9
|
+
# Model name -> model class, filled by the register_model() decorator.
MODEL_REGISTRY = {}
# Model name -> config dataclass, recorded only when register_model() receives one.
# build_model() merges that dataclass's defaults into the incoming cfg.
MODEL_DATACLASS_REGISTRY = {}
|
|
11
|
+
|
|
12
|
+
def build_model(cfg):
    """Instantiate the model registered under ``cfg._name``.

    If a config dataclass was registered for the model, its defaults are merged into
    ``cfg`` with ``merge_with_parent`` before the registered class's ``build_model``
    is called.

    Args:
        cfg: model config. Its ``_name`` attribute selects the registered model.

    Returns:
        Whatever the registered class's ``build_model(cfg)`` returns.

    Raises:
        AssertionError: if ``cfg`` has no ``_name`` or it names no registered model.
    """
    model = None
    model_type = getattr(cfg, "_name", None)

    if model_type in MODEL_REGISTRY:
        model = MODEL_REGISTRY[model_type]
        # register_model() allows dataclass=None, in which case nothing is recorded
        # here. Indexing unconditionally raised KeyError for such models.
        dc = MODEL_DATACLASS_REGISTRY.get(model_type)
        if dc is not None:
            # set defaults from dataclass
            cfg = merge_with_parent(dc(), cfg)

    assert model is not None, (
        f"Could not infer model type from {str(model_type)}. "
        + "Available models: "
        + str(MODEL_REGISTRY.keys())
        + " Requested model type: "
        + str(model_type)
    )

    model_instance = model.build_model(cfg)
    return model_instance
|
|
32
|
+
|
|
33
|
+
def register_model(name, dataclass=None):
    """Return a class decorator that registers a model under ``name``.

    Use it as ``@register_model("name", dataclass=SomeConfig)``. The decorated class
    must subclass ``GenHPF``. When ``dataclass`` is given, it is recorded as the
    model's config dataclass. An instance of it, with ``_name`` set to ``name``, is
    also stored in Hydra's ``ConfigStore`` under the ``model`` group.

    Args:
        name (str): the name of the model
        dataclass: optional config dataclass describing the model's options

    The returned decorator raises ValueError if ``name`` is already registered or the
    class does not extend ``GenHPF``. It returns the class unchanged.
    """

    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        if not issubclass(cls, GenHPF):
            raise ValueError(f"Model ({name}: {cls.__name__}) must extend GenHPF")
        MODEL_REGISTRY[name] = cls

        if dataclass is None:
            return cls

        MODEL_DATACLASS_REGISTRY[name] = dataclass
        store = ConfigStore.instance()
        default_node = dataclass()
        default_node._name = name
        store.store(name=name, group="model", node=default_node, provider="genhpf")
        return cls

    return register_model_cls
|
|
59
|
+
|
|
60
|
+
def import_models(models_dir, namespace):
    """Import every public Python module or sub-directory found in ``models_dir``.

    Entries whose names start with "_" or "." are skipped, as are non-directory
    entries that do not end in ".py". Each remaining entry is imported as
    ``<namespace>.<name>``, where ``name`` is the file name up to its first ".py",
    or the directory name.
    """
    for entry in os.listdir(models_dir):
        if entry.startswith(("_", ".")):
            continue
        is_source_file = entry.endswith(".py")
        if not (is_source_file or os.path.isdir(os.path.join(models_dir, entry))):
            continue
        module_name = entry.partition(".py")[0] if is_source_file else entry
        importlib.import_module(namespace + "." + module_name)
|
|
70
|
+
|
|
71
|
+
# automatically import any Python files in the models/ directory
# (import_models skips names starting with "_" or "." and imports each remaining
# .py file or sub-directory as genhpf.models.<name>. Presumably this runs each
# model module's @register_model decorators so MODEL_REGISTRY is populated at
# package import time. Confirm against the model modules.)
models_dir = os.path.dirname(__file__)
import_models(models_dir, "genhpf.models")
|