genhpf-1.0.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/__init__.py
ADDED
genhpf/configs/__init__.py
ADDED
@@ -0,0 +1,23 @@
import logging

from .configs import (
    BaseConfig,
    Config,
    CommonConfig,
    DistributedTrainingConfig,
    DatasetConfig,
    CheckpointConfig,
)
from .constants import ChoiceEnum

logger = logging.getLogger(__name__)

__all__ = [
    "BaseConfig",
    "Config",
    "CommonConfig",
    "DistributedTrainingConfig",
    "DatasetConfig",
    "CheckpointConfig",
    "ChoiceEnum",
]
genhpf/configs/configs.py
ADDED
@@ -0,0 +1,240 @@
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, List, Optional

from omegaconf import MISSING

from genhpf.configs.constants import LOG_FORMAT_CHOICES


@dataclass
class BaseConfig:
    """base configuration class"""

    _name: Optional[str] = None

    @staticmethod
    def name():
        return None

    def _get_all_attributes(self) -> List[str]:
        return [k for k in self.__dataclass_fields__.keys()]

    def _get_meta(self, attribute_name: str, meta: str, default: Optional[Any] = None) -> Any:
        return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        return self.__dataclass_fields__[attribute_name].name

    def _get_default(self, attribute_name: str) -> Any:
        if hasattr(self, attribute_name):
            if str(getattr(self, attribute_name)).startswith("${"):
                return str(getattr(self, attribute_name))
            elif str(self.__dataclass_fields__[attribute_name].default).startswith("${"):
                return str(self.__dataclass_fields__[attribute_name].default)
            elif getattr(self, attribute_name) != self.__dataclass_fields__[attribute_name].default:
                return getattr(self, attribute_name)

        f = self.__dataclass_fields__[attribute_name]
        if not isinstance(f.default_factory, _MISSING_TYPE):
            return f.default_factory()
        return f.default

    def _get_type(self, attribute_name: str) -> Any:
        return self.__dataclass_fields__[attribute_name].type

    def _get_help(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "help")

    def _get_argparse_const(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "argparse_const")

    def _get_argparse_alias(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "argparse_alias")

    def _get_choices(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "choices")


@dataclass
class CommonConfig(BaseConfig):
    debug: bool = field(default=False, metadata={"help": "enable debug mode"})
    no_progress_bar: bool = field(default=False, metadata={"help": "disable progress bar"})
    log_interval: int = field(default=100, metadata={"help": "log progress every N batches"})
    log_format: Optional[LOG_FORMAT_CHOICES] = field(default=None, metadata={"help": "log format to use"})
    log_file: Optional[str] = field(default=None, metadata={"help": "log file to copy metrics to."})
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Weights and Biases project name to use for logging"}
    )
    wandb_entity: Optional[str] = field(
        default=None, metadata={"help": "Weights and Biases entity (team) name to use for logging"}
    )
    seed: int = field(default=42, metadata={"help": "random seed"})
    all_gather_list_size: int = field(
        default=32768,
        metadata={"help": "number of bytes reserved for gathering stats from workers"},
    )


@dataclass
class DistributedTrainingConfig(BaseConfig):
    distributed_world_size: int = field(default=1, metadata={"help": "total number of GPUs across all nodes"})
    distributed_rank: Optional[int] = field(default=0, metadata={"help": "rank of the current worker"})
    distributed_backend: str = field(default="nccl", metadata={"help": "distributed backend"})
    distributed_init_method: Optional[str] = field(
        default=None,
        metadata={"help": "typically tcp://hostname:port that will be used to init distributed training"},
    )
    distributed_port: int = field(default=12355, metadata={"help": "port number for distributed training"})
    device_id: int = field(default=0, metadata={"help": "which GPU to use"})
    bucket_cap_mb: int = field(default=25, metadata={"help": "bucket size for reduction"})
    find_unused_parameters: bool = field(
        default=False, metadata={"help": "disable unused parameter detection when using distributed training"}
    )
    broadcast_buffers: bool = field(
        default=False,
        metadata={
            "help": "Copy non-trainable parameters between GPUs, such as batchnorm population statistics"
        },
    )


@dataclass
class DatasetConfig(BaseConfig):
    data_format: str = field(
        default="genhpf", metadata={"help": "data format. supported formats: genhpf, meds"}
    )
    data: str = field(default=MISSING, metadata={"help": "path to the data directory"})
    label: bool = field(default=False, metadata={"help": "whether to load labels from the dataset"})
    vocab_size: int = field(default=MISSING, metadata={"help": "size of the vocabulary"})
    pad_token_id: int = field(default=0, metadata={"help": "pad token id"})
    sep_token_id: int = field(default=102, metadata={"help": "sep token id"})
    dummy_token_id: int = field(default=101, metadata={"help": "dummy token id"})
    ignore_index: int = field(
        default=-100,
        metadata={
            "help": "specifies a target value that is ignored and does not contribute to "
            "the input gradient. only applied to cross-entropy loss"
        },
    )
    apply_mask: bool = field(default=False, metadata={"help": "whether to apply masking to the input tokens"})
    mask_prob: float = field(default=0.15, metadata={"help": "probability for masking tokens"})
    mask_unit: str = field(
        default="token", metadata={"help": "unit for masking. supported units: token, event"}
    )
    mask_token_id: int = field(default=103, metadata={"help": "mask token id"})
    num_workers: int = field(default=1, metadata={"help": "how many subprocesses to use for data loading"})
    batch_size: Optional[int] = field(default=None, metadata={"help": "number of examples in a batch"})
    train_subset: str = field(
        default="train", metadata={"help": "data subset name to use for training (e.g., train, valid, test)"}
    )
    valid_subset: Optional[str] = field(
        default="valid", metadata={"help": "comma separated list of data subset names to use for validation"}
    )
    test_subset: Optional[str] = field(
        default="test", metadata={"help": "comma separated list of data subset names to use for test"}
    )
    combine_train_subsets: Optional[bool] = field(
        default=None,
        metadata={
            "help": "whether to combine all training subsets into one dataset",
            "argparse_alias": "--combine-train",
        },
    )
    combine_valid_subsets: Optional[bool] = field(
        default=None,
        metadata={
            "help": "whether to combine all validation subsets into one dataset",
            "argparse_alias": "--combine-val",
        },
    )
    combine_test_subsets: Optional[bool] = field(
        default=None,
        metadata={
            "help": "whether to combine all test subsets into one dataset",
            "argparse_alias": "--combine-test",
        },
    )
    disable_validation: Optional[bool] = field(
        default=None, metadata={"help": "whether to disable validation during training"}
    )


@dataclass
class OptimizationConfig(BaseConfig):
    max_epoch: int = field(default=0, metadata={"help": "maximum number of epochs to train"})
    lr: float = field(default=1e-4, metadata={"help": "learning rate"})
    adam_betas: Any = field(default=(0.9, 0.999), metadata={"help": "betas for Adam optimizer"})
    adam_eps: float = field(default=1e-8, metadata={"help": "epsilon for Adam optimizer"})
    weight_decay: float = field(default=1e-8, metadata={"help": "weight decay for Adam optimizer"})


@dataclass
class CheckpointConfig(BaseConfig):
    save_dir: str = field(default="checkpoints", metadata={"help": "path to save checkpoints"})
    checkpoint_prefix: str = field(
        default="checkpoint", metadata={"help": "prefix to add to the checkpoint file name"}
    )
    checkpoint_suffix: str = field(default="", metadata={"help": "suffix to add to the checkpoint file name"})
    load_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "path to a checkpoint to load model weights from, if provided"}
    )
    no_save: bool = field(default=False, metadata={"help": "don't save checkpoints"})
    save_interval: int = field(default=1, metadata={"help": "save a checkpoint every N epochs"})
    keep_last_epochs: int = field(default=-1, metadata={"help": "keep last N epoch checkpoints"})
    no_last_checkpoints: bool = field(default=False, metadata={"help": "don't store last checkpoints"})
    best_checkpoint_metric: str = field(
        default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
    )
    maximize_best_checkpoint_metric: bool = field(
        default=False,
        metadata={"help": 'select the largest metric value for saving "best" checkpoints'},
    )
    patience: int = field(
        default=-1,
        metadata={
            "help": (
                "early stop training if valid performance doesn't "
                "improve for N consecutive validation runs; note "
                "that this is influenced by --validate-interval"
            )
        },
    )


@dataclass
class MEDSConfig(BaseConfig):
    output_predictions: bool = field(
        default=False,
        metadata={
            "help": "whether to output predictions. if turned on, `genhpf-test` will automatically "
            "output predictions of the test set specified by `dataset.test_subset` in the "
            "`meds.output_dir` directory."
        },
    )
    labels_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "a path to the label directory for MEDS dataset. this is required to store "
            "output predictions in the format expected by `meds-evaluation` when "
            "`meds.output_predictions` is turned on."
        },
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "a path to the output directory to store output predictions. "
            "this is only used when `meds.output_predictions` is turned on."
        },
    )


@dataclass
class Config(BaseConfig):
    common: CommonConfig = field(default_factory=CommonConfig)
    distributed_training: DistributedTrainingConfig = field(default_factory=DistributedTrainingConfig)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    meds: MEDSConfig = field(default_factory=MEDSConfig)
    model: Any = MISSING
    criterion: Any = None
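As a usage sketch (not part of the package diff): these dataclasses are plain OmegaConf structured configs, so the whole tree can be materialized and overridden programmatically. The path, vocabulary size, and learning rate below are placeholders for illustration, not defaults shipped by the package.

from omegaconf import OmegaConf

from genhpf.configs import Config

# build the full config tree from the dataclass defaults above
cfg = OmegaConf.structured(Config)

# dataset.data and dataset.vocab_size default to MISSING and must be provided;
# the concrete values here are made up for illustration
cfg.dataset.data = "/path/to/preprocessed/data"
cfg.dataset.vocab_size = 28996
cfg.optimization.lr = 5e-5

print(OmegaConf.to_yaml(cfg.dataset))  # defaults plus the overrides above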
genhpf/configs/constants.py
ADDED
@@ -0,0 +1,29 @@
from enum import Enum, EnumMeta
from typing import List


class StrEnumMeta(EnumMeta):
    # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
    # https://github.com/facebookresearch/hydra/issues/1156
    @classmethod
    def __instancecheck__(cls, other):
        return "enum" in str(type(other))


class StrEnum(Enum, metaclass=StrEnumMeta):
    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        return self.value == other

    def __repr__(self):
        return self.value

    def __hash__(self):
        return hash(str(self))


def ChoiceEnum(choices: List[str]):
    """return the Enum class used to enforce list of choices"""
    return StrEnum("Choices", {k: k for k in choices})


LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm", "csv"])
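A quick sketch (not part of the diff) of why ChoiceEnum is convenient: the StrEnum equality and hashing overrides let enum-valued config fields be compared directly against plain strings. MASK_UNIT_CHOICES below is a made-up name used only for illustration; the package itself declares mask_unit as a plain str.

from genhpf.configs.constants import LOG_FORMAT_CHOICES, ChoiceEnum

# enum members compare equal to their raw string values
assert LOG_FORMAT_CHOICES.tqdm == "tqdm"
assert str(LOG_FORMAT_CHOICES.json) == "json"

# the same factory can constrain any other option set
MASK_UNIT_CHOICES = ChoiceEnum(["token", "event"])
assert MASK_UNIT_CHOICES.event == "event"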
genhpf/configs/initialize.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import _MISSING_TYPE
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf

from genhpf.configs import Config

logger = logging.getLogger(__name__)


def hydra_init(cfg_name="config") -> None:
    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=Config)

    for k in Config.__dataclass_fields__:
        v = Config.__dataclass_fields__[k].default
        if isinstance(v, _MISSING_TYPE):
            v = Config.__dataclass_fields__[k].default_factory
            if not isinstance(v, _MISSING_TYPE):
                v = v()
        try:
            cs.store(name=k, node=v)
        except BaseException:
            logger.error(f"{k} - {v}")
            raise


def add_defaults(cfg: DictConfig) -> None:
    """This function adds default values that are stored in dataclasses that hydra doesn't know about"""

    from genhpf.criterions import CRITERION_DATACLASS_REGISTRY
    from genhpf.models import MODEL_DATACLASS_REGISTRY
    from genhpf.configs.utils import merge_with_parent
    from typing import Any

    OmegaConf.set_struct(cfg, False)

    for k, v in Config.__dataclass_fields__.items():
        field_cfg = cfg.get(k)
        if field_cfg is not None and v.type == Any:
            dc = None

            if isinstance(field_cfg, str):
                field_cfg = DictConfig({"_name": field_cfg})
                field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]

            name = field_cfg.get("_name")

            if k == "model":
                dc = MODEL_DATACLASS_REGISTRY.get(name)
            elif k == "criterion":
                dc = CRITERION_DATACLASS_REGISTRY.get(name)

            if dc is not None:
                cfg[k] = merge_with_parent(dc, field_cfg)
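For context (not part of the diff): `hydra_init` registers the schema in Hydra's ConfigStore, so a config can be composed outside of a `@hydra.main` entry point roughly as sketched below. Exact `initialize` arguments depend on the installed hydra-core version, and the dataset overrides are placeholders.

from hydra import compose, initialize

from genhpf.configs.initialize import hydra_init

hydra_init()  # stores the top-level Config node under the name "config"

with initialize(version_base=None, config_path=None):
    cfg = compose(
        config_name="config",
        overrides=["dataset.data=/tmp/genhpf_data", "dataset.vocab_size=30000"],
    )

print(cfg.common.seed)  # 42, from CommonConfig defaults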
genhpf/configs/utils.py
ADDED
@@ -0,0 +1,29 @@
from dataclasses import is_dataclass
from omegaconf import OmegaConf, open_dict

from genhpf.configs import BaseConfig


def merge_with_parent(dc: BaseConfig, cfg: BaseConfig, remove_missing=False):
    if remove_missing:

        def remove_missing_rec(src_keys, target_cfg):
            if is_dataclass(target_cfg):
                target_keys = set(target_cfg.__dataclass_fields__.keys())
            else:
                target_keys = set(target_cfg.keys())

            for k in list(src_keys.keys()):
                if k not in target_keys:
                    del src_keys[k]
                elif OmegaConf.is_config(src_keys[k]):
                    tgt = getattr(target_cfg, k)
                    if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")):
                        remove_missing_rec(src_keys[k], tgt)

        with open_dict(cfg):
            remove_missing_rec(cfg, dc)

    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
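A small usage sketch (not part of the diff): `merge_with_parent` layers a partial user config on top of the dataclass defaults and re-locks the result with struct mode. CommonConfig is used here only because it is the simplest schema above; the override values are illustrative.

from omegaconf import OmegaConf

from genhpf.configs import CommonConfig
from genhpf.configs.utils import merge_with_parent

# partial overrides, e.g. parsed from the command line or a YAML file
user_cfg = OmegaConf.create({"seed": 1234, "log_interval": 10})

merged = merge_with_parent(CommonConfig(), user_cfg)
print(merged.seed)        # 1234 (overridden)
print(merged.log_format)  # None (dataclass default)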
genhpf/criterions/__init__.py
ADDED
@@ -0,0 +1,74 @@
import importlib
import os

from hydra.core.config_store import ConfigStore

from genhpf.criterions.criterion import BaseCriterion  # noqa
from genhpf.configs.utils import merge_with_parent

CRITERION_REGISTRY = {}
CRITERION_DATACLASS_REGISTRY = {}


def build_criterion(cfg) -> BaseCriterion:
    criterion = None
    criterion_type = getattr(cfg, "_name", None)

    if criterion_type in CRITERION_REGISTRY:
        criterion = CRITERION_REGISTRY[criterion_type]
        # set defaults from dataclass
        dc = CRITERION_DATACLASS_REGISTRY[criterion_type]
        cfg = merge_with_parent(dc(), cfg)

    assert criterion is not None, (
        f"Could not infer criterion type from {str(criterion_type)}. "
        + "Available criterions: "
        + str(CRITERION_REGISTRY.keys())
        + " Requested criterion type: "
        + str(criterion_type)
    )

    return criterion.build_criterion(cfg)


def register_criterion(name, dataclass=None):
    """
    New criterion types can be added with the :func:`register_criterion`
    function decorator.

    Args:
        name (str): the name of the criterion
    """

    def register_criterion_cls(cls):
        if name in CRITERION_REGISTRY:
            raise ValueError(f"Cannot register duplicate criterion ({name})")
        if not issubclass(cls, BaseCriterion):
            raise ValueError(f"Criterion ({name}: {cls.__name__}) must extend BaseCriterion")
        CRITERION_REGISTRY[name] = cls
        if dataclass is not None:
            CRITERION_DATACLASS_REGISTRY[name] = dataclass

            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="criterion", node=node, provider="genhpf")

        return cls

    return register_criterion_cls


def import_criterions(criterions_dir, namespace):
    for file in os.listdir(criterions_dir):
        path = os.path.join(criterions_dir, file)
        if (
            not file.startswith("_")
            and not file.startswith(".")
            and (file.endswith(".py") or os.path.isdir(path))
        ):
            criterion_name = file[: file.find(".py")] if file.endswith(".py") else file
            importlib.import_module(namespace + "." + criterion_name)


# automatically import any Python files in the criterions/ directory
criterions_dir = os.path.dirname(__file__)
import_criterions(criterions_dir, "genhpf.criterions")
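As an illustrative sketch (not part of the diff) of the registration mechanism: a new criterion is declared with the decorator and, because a dataclass is passed, also becomes selectable through Hydra's `criterion` group. Everything named `hinge_*` here is hypothetical, and since criterion.py is not reproduced in this section, the sketch only fills in the hooks that binary_cross_entropy below is seen to implement; BaseCriterion may require more.

from dataclasses import dataclass, field
from typing import List, Tuple

import torch

from genhpf.criterions import BaseCriterion, register_criterion
from genhpf.criterions.criterion import CriterionConfig


@dataclass
class HingeLossConfig(CriterionConfig):
    margin: float = field(default=1.0, metadata={"help": "margin for the hinge loss"})


@register_criterion("hinge_loss", dataclass=HingeLossConfig)
class HingeLoss(BaseCriterion):
    """Hypothetical criterion, registered the same way as binary_cross_entropy."""

    def __init__(self, cfg: HingeLossConfig):
        super().__init__(cfg)
        self.margin = cfg.margin

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        # map {0, 1} targets to {-1, +1} and apply a summed hinge loss
        signs = targets.float() * 2 - 1
        loss = torch.clamp(self.margin - signs * logits.flatten(), min=0).sum()
        return loss, [loss.detach().item()]

    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
        return targets.numel()

With the dataclass passed to the decorator, `build_criterion` can then merge user overrides into the HingeLossConfig defaults and instantiate the class via `criterion.build_criterion(cfg)`, exactly as it does for the built-in criterions.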
genhpf/criterions/binary_cross_entropy.py
ADDED
@@ -0,0 +1,114 @@
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

import genhpf.utils.utils as utils
from genhpf.criterions import BaseCriterion, register_criterion
from genhpf.criterions.criterion import CriterionConfig
from genhpf.loggings import meters, metrics
from genhpf.loggings.meters import safe_round


@dataclass
class BinaryCrossEntropyConfig(CriterionConfig):
    threshold: float = field(default=0.5, metadata={"help": "threshold value for binary classification"})


@register_criterion("binary_cross_entropy", dataclass=BinaryCrossEntropyConfig)
class BinaryCrossEntropy(BaseCriterion):
    def __init__(self, cfg: BinaryCrossEntropyConfig):
        super().__init__(cfg)

        if self.task_names is not None and len(self.task_names) > 1:
            raise ValueError(
                "binary_cross_entropy only supports single task training. if you want "
                "to train multiple tasks, use multi_task_criterion instead."
            )

        self.threshold = cfg.threshold

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        logits = logits.flatten()
        targets = targets.float()
        probs = torch.sigmoid(logits)
        loss = F.binary_cross_entropy(input=probs, target=targets, reduction="sum")
        return loss, [loss.detach().item()]

    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
        if "sample_size" in sample:
            sample_size = sample["sample_size"]
        else:
            sample_size = targets.numel()
        return sample_size

    def get_logging_outputs(
        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
    ) -> List[Dict[str, Any]]:
        with torch.no_grad():
            probs = torch.sigmoid(logits.flatten())
            outputs = probs > self.threshold

            if probs.numel() == 0:
                corr = 0
                count = 0
            else:
                count = float(probs.numel())
                corr = (outputs == targets).sum().item()

            logging_output["correct"] = corr
            logging_output["count"] = count

            # report aucs only in eval mode
            if not self.training:
                logging_output["_y_true"] = targets.cpu().numpy()
                logging_output["_y_score"] = probs.cpu().numpy()

        return logging_output

    @staticmethod
    def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
        """Aggregate logging outputs from data parallel training."""
        if prefix is None:
            prefix = ""
        elif prefix is not None and not prefix.endswith("_"):
            prefix = prefix + "_"

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))

        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))

        metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)

        if "_y_true" in logging_outputs[0] and "_y_score" in logging_outputs[0]:
            y_true = np.concatenate([log.get("_y_true", []) for log in logging_outputs])
            y_score = np.concatenate([log.get("_y_score", []) for log in logging_outputs])

            metrics.log_custom(meters.AUCMeter, f"_{prefix}auc", y_score, y_true)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}total", total)

        if total > 0:
            metrics.log_derived(
                f"{prefix}accuracy",
                lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
                if meters[f"_{prefix}total"].sum > 0
                else float("nan"),
            )

    def post_validate(self, stats, agg, **kwargs):
        for key in agg.keys():
            if key.startswith("_") and key.endswith("auc"):
                stats[key[1:-3] + "auroc"] = agg[key].auroc
                stats[key[1:-3] + "auprc"] = agg[key].auprc

        return stats
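One implementation note (not part of the diff): `compute_loss` applies `torch.sigmoid` and then `F.binary_cross_entropy`, while the sibling `binary_cross_entropy_with_logits.py` listed above presumably uses the fused form, which is numerically safer for large-magnitude logits. A quick sketch of the equivalence on dummy tensors:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8)                     # dummy model outputs
targets = torch.randint(0, 2, (8,)).float()

# two-step form used in BinaryCrossEntropy.compute_loss
loss_two_step = F.binary_cross_entropy(torch.sigmoid(logits), targets, reduction="sum")

# fused form; equal up to floating-point error but stable for extreme logits
loss_fused = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum")

assert torch.allclose(loss_two_step, loss_fused, atol=1e-6)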