genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
import bisect
|
|
7
|
+
import time
|
|
8
|
+
from collections import OrderedDict
|
|
9
|
+
from typing import Dict, Optional
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import numpy as np
|
|
13
|
+
from sklearn.metrics import average_precision_score, roc_auc_score
|
|
14
|
+
|
|
15
|
+
def warn(*args, **kwargs):
    """No-op replacement for ``warnings.warn``; accepts and ignores any arguments."""
    return None


# Install the no-op above as ``warnings.warn``. From this point on, any call
# that looks up ``warnings.warn`` at call time does nothing.
import warnings

warnings.warn = warn
|
|
19
|
+
|
|
20
|
+
def type_as(a, b):
    """Return ``a`` converted with ``a.to(b)`` when both arguments are tensors.

    If either argument is not a tensor, ``a`` is returned unchanged.
    """
    both_are_tensors = torch.is_tensor(a) and torch.is_tensor(b)
    if not both_are_tensors:
        return a
    return a.to(b)
|
|
25
|
+
|
|
26
|
+
class Meter:
    """Common interface implemented by every meter."""

    def __init__(self):
        """The base meter has nothing to initialise."""

    def state_dict(self):
        """Return the meter's serialisable state; the base meter has none."""
        return dict()

    def load_state_dict(self, state_dict):
        """Accept a state mapping and ignore it; the base meter has no state."""
        return None

    def reset(self):
        """Clear accumulated state. Not implemented in the base class."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Value reported when logging. Not implemented in the base class."""
        raise NotImplementedError
|
|
45
|
+
|
|
46
|
+
def safe_round(number, ndigits):
    """Round ``number`` to ``ndigits`` digits when it can be rounded.

    Objects that define ``__round__`` go straight to the built-in ``round``.
    A single-element tensor, or a zero-dimensional object with an ``item``
    method, is unwrapped with ``.item()`` and rounded recursively. Anything
    else is returned as-is.
    """
    if hasattr(number, "__round__"):
        return round(number, ndigits)

    single_element_tensor = torch.is_tensor(number) and number.numel() == 1
    if single_element_tensor or (np.ndim(number) == 0 and hasattr(number, "item")):
        return safe_round(number.item(), ndigits)

    return number
|
|
55
|
+
|
|
56
|
+
class SumMeter(Meter):
    """Keeps a running total of every value passed to :meth:`update`."""

    def __init__(self, round: Optional[int] = None):
        """Store the optional rounding precision and start from zero."""
        self.round = round
        self.reset()

    def reset(self):
        """Set the running total back to zero."""
        self.sum = 0  # total of all updates so far

    def update(self, val):
        """Add ``val`` to the total; ``None`` is ignored."""
        if val is None:
            return
        self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        """Return the total and the rounding precision."""
        return dict(sum=self.sum, round=self.round)

    def load_state_dict(self, state_dict):
        """Restore the total; a missing ``"round"`` entry becomes ``None``."""
        self.sum = state_dict["sum"]
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        """The running total, rounded when a precision is configured."""
        total = self.sum
        if self.round is None or total is None:
            return total
        return safe_round(total, self.round)
|
|
86
|
+
|
|
87
|
+
class AverageMeter(Meter):
    """Tracks the latest value and the weighted average of all updates."""

    def __init__(self, round: Optional[int] = None):
        """Store the optional rounding precision and clear all statistics."""
        self.round = round
        self.reset()

    def reset(self):
        """Forget the latest value and zero the sum and count."""
        self.val = None  # value from the most recent update
        self.sum = 0  # weighted sum over all updates
        self.count = 0  # total weight over all updates

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        ``None`` values are ignored. The sum and count only change when
        ``n > 0``; the latest value is stored either way.
        """
        if val is None:
            return
        self.val = val
        if n > 0:
            self.sum = type_as(self.sum, val) + (val * n)
            self.count = type_as(self.count, n) + n

    def state_dict(self):
        """Return the latest value, sum, count and rounding precision."""
        return dict(
            val=self.val,
            sum=self.sum,
            count=self.count,
            round=self.round,
        )

    def load_state_dict(self, state_dict):
        """Restore the statistics; a missing ``"round"`` entry becomes ``None``."""
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """``sum / count`` when the count is positive, else the latest value."""
        if self.count > 0:
            return self.sum / self.count
        return self.val

    @property
    def smoothed_value(self) -> float:
        """The average, rounded when a precision is configured."""
        average = self.avg
        if self.round is None or average is None:
            return average
        return safe_round(average, self.round)
|
|
130
|
+
|
|
131
|
+
class TimeMeter(Meter):
    """Measures how many events occur per second on average."""

    def __init__(self, init: int = 0, n: int = 0, round: Optional[int] = None):
        """Store the rounding precision and start timing.

        ``init`` is the elapsed time already accounted for and ``n`` the
        number of events already counted.
        """
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        """Restart the clock from ``init`` seconds and ``n`` events."""
        self.init = init
        self.start = time.perf_counter()
        self.n = n
        self.i = 0  # number of update() calls since the last reset

    def update(self, val=1):
        """Add ``val`` events and count one more update call."""
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        """Return the elapsed time (stored as ``"init"``), event count and rounding."""
        return dict(init=self.elapsed_time, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        """Restart the clock from a saved state.

        Old state dicts that contain a ``"start"`` entry only restore the
        elapsed time; newer ones also restore the event count.
        """
        reset_args = {"init": state_dict["init"]}
        if "start" not in state_dict:
            reset_args["n"] = state_dict["n"]
        self.reset(**reset_args)
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Events per second: event count divided by elapsed time."""
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        """Seconds carried over in ``init`` plus seconds since the last reset."""
        since_start = time.perf_counter() - self.start
        return self.init + since_start

    @property
    def smoothed_value(self) -> float:
        """The events-per-second rate, rounded when a precision is configured."""
        rate = self.avg
        if self.round is None or rate is None:
            return rate
        return safe_round(rate, self.round)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class AUCMeter(Meter):
    """Stores scores / targets to compute AUROC and AUPRC."""

    def __init__(self,):
        self.reset()

    def reset(self):
        """Drop every stored batch of scores and targets."""
        self.scores = []
        self.targets = []

    def update(self, prob, target):
        """Append one batch of scores and targets.

        Tensor inputs are moved to the CPU and converted to NumPy arrays;
        other inputs are stored unchanged.
        """
        if torch.is_tensor(prob):
            prob = prob.cpu().numpy()
        if torch.is_tensor(target):
            target = target.cpu().numpy()

        self.scores.append(prob)
        self.targets.append(target)

    def state_dict(self):
        """Return the stored batches of scores and targets."""
        return {
            "scores": self.scores,
            "targets": self.targets,
        }

    def load_state_dict(self, state_dict):
        """Restore stored batches; a missing ``"round"`` entry becomes ``None``."""
        self.scores = state_dict["scores"]
        self.targets = state_dict["targets"]
        self.round = state_dict.get("round", None)

    def _has_data(self):
        """Return True when at least one batch of scores and targets is stored."""
        return len(self.targets) > 0 and len(self.scores) > 0

    @property
    def auroc(self):
        """Macro-averaged ROC AUC over everything stored so far.

        Returns NaN when nothing has been recorded yet or when
        ``roc_auc_score`` raises ``ValueError``.
        """
        # np.concatenate raises ValueError on an empty list; report the
        # metric as undefined instead, like the other failure cases below.
        if not self._has_data():
            return float("nan")

        y_true = np.concatenate(self.targets)
        y_score = np.concatenate(self.scores)
        # Unlike auprc, targets are not one-hot encoded here when the shapes differ.
        # For 2-D targets with more than 127 rows, drop the columns that have
        # no non-zero target at all.
        # NOTE(review): the reason for the 127-row threshold is not evident
        # from this file -- confirm.
        if y_true.shape[0] > 127 and len(y_true.shape) > 1:
            mask = (y_true.sum(axis=0) != 0)
            y_true = y_true[:, mask]
            y_score = y_score[:, mask]
        try:
            return roc_auc_score(y_true=y_true, y_score=y_score, average='macro')
        except ValueError:
            return float("nan")

    @property
    def auprc(self):
        """Micro-averaged average precision over everything stored so far.

        Returns NaN when nothing has been recorded yet or when
        ``average_precision_score`` raises ``ValueError``.
        """
        # Same empty-input guard as in auroc.
        if not self._has_data():
            return float("nan")

        y_true = np.concatenate(self.targets)
        y_score = np.concatenate(self.scores)
        if y_true.shape != y_score.shape:
            # One-hot encode the targets so they match the score matrix.
            # Assumes y_true holds integer class indices and y_score is
            # 2-D -- TODO confirm against callers.
            y_true = np.eye(y_score.shape[1])[y_true]
        try:
            return average_precision_score(y_true=y_true, y_score=y_score, average='micro')
        except ValueError:
            return float("nan")

    @property
    def smoothed_value(self) -> float:
        """Always raises: this meter has no single value to log."""
        raise AttributeError(
            "AUC meter cannot have smoothed values. Please "
            "make sure the key of this meter starts with '_'."
        )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class StopwatchMeter(Meter):
    """Accumulates the total and average duration of a timed event, in seconds."""

    def __init__(self, round: Optional[int] = None):
        """Start with zero accumulated time and no running interval."""
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        """Mark the beginning of a timed interval."""
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        """Close the running interval and add its duration to the total.

        Does nothing if :meth:`start` has not been called. ``prehook``, when
        given, is invoked before the end time is read. ``n`` is added to the
        event count.
        """
        if self.start_time is None:
            return
        if prehook is not None:
            prehook()
        delta = time.perf_counter() - self.start_time
        self.sum = self.sum + delta
        self.n = type_as(self.n, n) + n

    def reset(self):
        """Zero the accumulated time and count, then start a new interval."""
        self.sum = 0  # total time spent between start() and stop()
        self.n = 0  # total n passed to stop()
        self.start()

    def state_dict(self):
        """Return the accumulated time, event count and rounding precision."""
        return dict(sum=self.sum, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        """Restore the totals; the stopwatch is left without a running interval."""
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """``sum / n`` when ``n`` is positive, otherwise the raw sum."""
        if self.n > 0:
            return self.sum / self.n
        return self.sum

    @property
    def elapsed_time(self):
        """Seconds since :meth:`start`, or ``0.0`` if it was never called."""
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        """The average once any time has accumulated, else the elapsed time."""
        if self.sum > 0:
            duration = self.avg
        else:
            duration = self.elapsed_time
        if self.round is None or duration is None:
            return duration
        return safe_round(duration, self.round)
|
|
302
|
+
|
|
303
|
+
class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        # The priority list must exist before OrderedDict.__init__ runs,
        # because initial items are inserted through __setitem__, which
        # reads and updates it.
        self.priorities = []
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Insert a meter given as a ``(priority, meter)`` pair.

        Only the meter is stored under ``key``. The dictionary is then
        reordered by ascending priority; equal priorities keep insertion
        order. Assigning to an existing key is rejected.
        """
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        # The list length is a unique tie-breaker, so equal priorities never
        # fall through to comparing keys.
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, ordered_key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(ordered_key)

    def add_meter(self, key, meter, priority):
        """Register ``meter`` under ``key`` with the given ``priority``."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return a list of ``(priority, key, class name, meter state)`` tuples.

        Derived meters are left out because they cannot be serialised.
        """
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        """Replace all meters with those described by ``state_dict``.

        Each meter class is looked up by name in this module's globals and
        instantiated without arguments before its state is loaded.
        """
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            # Derived meters compute their value from the whole dictionary.
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values, skipping keys that start with ``"_"``."""
        return OrderedDict(
            [
                (key, self.get_smoothed_value(key))
                for key in self.keys()
                if not key.startswith("_")
            ]
        )

    def reset(self):
        """Reset Meter instances; derived meters are skipped."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            # fn is called with the owning MetersDict and returns the value.
            self.fn = fn

        def reset(self):
            pass
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import uuid
|
|
3
|
+
from collections import defaultdict, OrderedDict
|
|
4
|
+
from typing import Callable, List, Optional
|
|
5
|
+
|
|
6
|
+
from .meters import *
|
|
7
|
+
|
|
8
|
+
# Aggregators registered under a name, keyed by that name.
_aggregators = OrderedDict()
# Aggregators that currently receive logged values, keyed by name.
_active_aggregators = OrderedDict()
# Per-name activation counter; names not seen yet count as 0.
_active_aggregators_cnt = defaultdict(int)
|
|
11
|
+
|
|
12
|
+
def reset() -> None:
    """Discard every aggregator and recreate the always-active ``'default'`` one."""
    for registry in (_aggregators, _active_aggregators, _active_aggregators_cnt):
        registry.clear()

    default_meters = MetersDict()
    _aggregators['default'] = default_meters
    _active_aggregators['default'] = default_meters
    _active_aggregators_cnt['default'] = 1


# Set up the default aggregator as soon as the module is imported.
reset()
|
|
22
|
+
|
|
23
|
+
@contextlib.contextmanager
def aggregate(name=None, new_root=False):
    """Activate an aggregator for the duration of a ``with`` block.

    While the block runs, the yielded :class:`MetersDict` is among the
    active aggregators returned by ``get_active_aggregators``.

    Args:
        name: Name of the aggregator. With a name, the aggregator is
            registered in ``_aggregators`` and reused on later calls with the
            same name; ``'default'`` is not allowed. Without a name, a
            temporary aggregator with a random UUID name is created and is
            not registered.
        new_root: When true, all other active aggregators are deactivated
            for the duration of the block and restored afterwards.

    Yields:
        The activated :class:`MetersDict`.
    """
    if name is None:
        # generate a temporary name
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != 'default'
        agg = _aggregators.setdefault(name, MetersDict())

    if new_root:
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()
        backup_aggregators_cnt = _active_aggregators_cnt.copy()
        _active_aggregators_cnt.clear()

    _active_aggregators[name] = agg
    _active_aggregators_cnt[name] += 1

    try:
        yield agg
    finally:
        # Run the teardown even when the with-body raises. Otherwise the
        # aggregator would stay active and, with new_root, the previously
        # active aggregators would never be restored.
        _active_aggregators_cnt[name] -= 1
        if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
            del _active_aggregators[name]

        if new_root:
            _active_aggregators.clear()
            _active_aggregators.update(backup_aggregators)
            _active_aggregators_cnt.clear()
            _active_aggregators_cnt.update(backup_aggregators_cnt)
|
|
53
|
+
|
|
54
|
+
def get_active_aggregators() -> List[MetersDict]:
    """Return a new list containing the currently active aggregators."""
    return [*_active_aggregators.values()]
|
|
56
|
+
|
|
57
|
+
def log_scalar(key, value, weight=1, priority=10, round=None):
    """Record ``value`` with ``weight`` under ``key`` in every active aggregator.

    An aggregator that has no meter for ``key`` first gets an
    ``AverageMeter`` created with ``round`` and registered at ``priority``.
    """
    for meters in get_active_aggregators():
        missing = key not in meters
        if missing:
            meters.add_meter(key, AverageMeter(round=round), priority)
        meters[key].update(value, weight)
|
|
68
|
+
|
|
69
|
+
def log_scalar_sum(key, value, priority=10, round=None):
    """Add ``value`` to the meter stored under ``key`` in every active aggregator.

    An aggregator that has no meter for ``key`` first gets a ``SumMeter``
    created with ``round`` and registered at ``priority``.
    """
    for meters in get_active_aggregators():
        missing = key not in meters
        if missing:
            meters.add_meter(key, SumMeter(round=round), priority)
        meters[key].update(value)
|
|
79
|
+
|
|
80
|
+
def log_derived(key, fn: Callable[[MetersDict], float], priority=20):
    """Register a derived meter under ``key`` in every active aggregator.

    ``fn`` is wrapped in a ``MetersDict._DerivedMeter``. Aggregators that
    already have a meter for ``key`` are left untouched.
    """
    for meters in get_active_aggregators():
        if key in meters:
            continue
        meters.add_meter(key, MetersDict._DerivedMeter(fn), priority)
|
|
88
|
+
|
|
89
|
+
def log_speed(key, value, priority=30, round=None):
    """Feed ``value`` into the ``TimeMeter`` under ``key`` in every active aggregator.

    The first call for a given aggregator only creates the meter and resets
    it; ``value`` is not recorded on that call. Later calls update the
    existing meter with ``value``.
    """
    for meters in get_active_aggregators():
        if key in meters:
            meters[key].update(value)
            continue
        meters.add_meter(key, TimeMeter(round=round), priority)
        meters[key].reset()  # restart the clock on the first call
|
|
101
|
+
|
|
102
|
+
def log_start_time(key, priority=40, round=None):
    """Start the stopwatch stored under ``key`` in every active aggregator.

    An aggregator that has no meter for ``key`` first gets a
    ``StopwatchMeter`` created with ``round`` and registered at ``priority``.
    """
    for meters in get_active_aggregators():
        missing = key not in meters
        if missing:
            meters.add_meter(key, StopwatchMeter(round=round), priority)
        meters[key].start()
|
|
111
|
+
|
|
112
|
+
def log_stop_time(key, weight=0.0, prehook=None):
    """Stop the meter stored under ``key`` in every active aggregator that has one.

    ``weight`` and ``prehook`` are passed through to the meter's ``stop``
    method. Aggregators without a meter for ``key`` are skipped.
    """
    for meters in get_active_aggregators():
        if key not in meters:
            continue
        meters[key].stop(weight, prehook)
|
|
120
|
+
|
|
121
|
+
def log_custom(new_meter_fn: Callable[[], Meter], key, *args, priority=50, **kwargs):
    """Update a caller-supplied meter type under ``key`` in every active aggregator.

    An aggregator that has no meter for ``key`` first gets one built by
    calling ``new_meter_fn()`` and registered at ``priority``. The remaining
    positional and keyword arguments are forwarded to the meter's ``update``.
    """
    for meters in get_active_aggregators():
        missing = key not in meters
        if missing:
            meters.add_meter(key, new_meter_fn(), priority)
        meters[key].update(*args, **kwargs)
|
|
132
|
+
|
|
133
|
+
def reset_meter(name, key) -> None:
    """Reset the meter ``key`` of aggregator ``name``; do nothing if it is not found."""
    target = get_meter(name, key)
    if target is None:
        return
    target.reset()
|
|
137
|
+
|
|
138
|
+
def reset_meters(name) -> None:
    """Reset all meters of aggregator ``name``; do nothing if it is not found."""
    target = get_meters(name)
    if target is None:
        return
    target.reset()
|
|
142
|
+
|
|
143
|
+
def get_meter(name, key) -> Meter:
    """Return the meter ``key`` of aggregator ``name``.

    Returns ``None`` when the aggregator is not registered or has no meter
    under ``key``.
    """
    try:
        meters = _aggregators[name]
    except KeyError:
        return None
    return meters.get(key, None)
|
|
147
|
+
|
|
148
|
+
def get_meters(name) -> MetersDict:
    """Return the aggregator registered as ``name``, or ``None`` if there is none."""
    return _aggregators[name] if name in _aggregators else None
|
|
150
|
+
|
|
151
|
+
def get_smoothed_values(name) -> Dict[str, float]:
    """Return the smoothed values of aggregator ``name``.

    Raises ``KeyError`` when no aggregator is registered under ``name``.
    """
    meters = _aggregators[name]
    return meters.get_smoothed_values()
|
|
153
|
+
|
|
154
|
+
def state_dict():
    """Return an ordered mapping from aggregator name to that aggregator's state."""
    snapshot = OrderedDict()
    for name, agg in _aggregators.items():
        snapshot[name] = agg.state_dict()
    return snapshot
|