genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """isort:skip_file"""
+
+ from importlib.metadata import PackageNotFoundError, version
+
+ __package_name__ = "genhpf"
+ try:
+     __version__ = version(__package_name__)
+ except PackageNotFoundError:
+     __version__ = "unknown"
genhpf/configs/__init__.py ADDED
@@ -0,0 +1,23 @@
+ import logging
+
+ from .configs import (
+     BaseConfig,
+     Config,
+     CommonConfig,
+     DistributedTrainingConfig,
+     DatasetConfig,
+     CheckpointConfig,
+ )
+ from .constants import ChoiceEnum
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = [
+     "BaseConfig",
+     "Config",
+     "CommonConfig",
+     "DistributedTrainingConfig",
+     "DatasetConfig",
+     "CheckpointConfig",
+     "ChoiceEnum",
+ ]
genhpf/configs/config.yaml ADDED
@@ -0,0 +1,8 @@
+ # @package _group_
+
+ hydra:
+   run:
+     dir: .
+
+ defaults:
+   - ...
genhpf/configs/configs.py ADDED
@@ -0,0 +1,240 @@
+ from dataclasses import _MISSING_TYPE, dataclass, field
+ from typing import Any, List, Optional
+
+ from omegaconf import MISSING
+
+ from genhpf.configs.constants import LOG_FORMAT_CHOICES
+
+
+ @dataclass
+ class BaseConfig:
+     """base configuration class"""
+
+     _name: Optional[str] = None
+
+     @staticmethod
+     def name():
+         return None
+
+     def _get_all_attributes(self) -> List[str]:
+         return [k for k in self.__dataclass_fields__.keys()]
+
+     def _get_meta(self, attribute_name: str, meta: str, default: Optional[Any] = None) -> Any:
+         return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)
+
+     def _get_name(self, attribute_name: str) -> str:
+         return self.__dataclass_fields__[attribute_name].name
+
+     def _get_default(self, attribute_name: str) -> Any:
+         if hasattr(self, attribute_name):
+             if str(getattr(self, attribute_name)).startswith("${"):
+                 return str(getattr(self, attribute_name))
+             elif str(self.__dataclass_fields__[attribute_name].default).startswith("${"):
+                 return str(self.__dataclass_fields__[attribute_name].default)
+             elif getattr(self, attribute_name) != self.__dataclass_fields__[attribute_name].default:
+                 return getattr(self, attribute_name)
+
+         f = self.__dataclass_fields__[attribute_name]
+         if not isinstance(f.default_factory, _MISSING_TYPE):
+             return f.default_factory()
+         return f.default
+
+     def _get_type(self, attribute_name: str) -> Any:
+         return self.__dataclass_fields__[attribute_name].type
+
+     def _get_help(self, attribute_name: str) -> Any:
+         return self._get_meta(attribute_name, "help")
+
+     def _get_argparse_const(self, attribute_name: str) -> Any:
+         return self._get_meta(attribute_name, "argparse_const")
+
+     def _get_argparse_alias(self, attribute_name: str) -> Any:
+         return self._get_meta(attribute_name, "argparse_alias")
+
+     def _get_choices(self, attribute_name: str) -> Any:
+         return self._get_meta(attribute_name, "choices")
+
+
+ @dataclass
+ class CommonConfig(BaseConfig):
+     debug: bool = field(default=False, metadata={"help": "enable debug mode"})
+     no_progress_bar: bool = field(default=False, metadata={"help": "disable progress bar"})
+     log_interval: int = field(default=100, metadata={"help": "log progress every N batches"})
+     log_format: Optional[LOG_FORMAT_CHOICES] = field(default=None, metadata={"help": "log format to use"})
+     log_file: Optional[str] = field(default=None, metadata={"help": "log file to copy metrics to."})
+     wandb_project: Optional[str] = field(
+         default=None, metadata={"help": "Weights and Biases project name to use for logging"}
+     )
+     wandb_entity: Optional[str] = field(
+         default=None, metadata={"help": "Weights and Biases entity (team) name to use for logging"}
+     )
+     seed: int = field(default=42, metadata={"help": "random seed"})
+     all_gather_list_size: int = field(
+         default=32768,
+         metadata={"help": "number of bytes reserved for gathering stats from workers"},
+     )
+
+
+ @dataclass
+ class DistributedTrainingConfig(BaseConfig):
+     distributed_world_size: int = field(default=1, metadata={"help": "total number of GPUs across all nodes"})
+     distributed_rank: Optional[int] = field(default=0, metadata={"help": "rank of the current worker"})
+     distributed_backend: str = field(default="nccl", metadata={"help": "distributed backend"})
+     distributed_init_method: Optional[str] = field(
+         default=None,
+         metadata={"help": "typically tcp://hostname:port that will be used to " "init distributed training"},
+     )
+     distributed_port: int = field(default=12355, metadata={"help": "port number for distributed training"})
+     device_id: int = field(default=0, metadata={"help": "which GPU to use"})
+     bucket_cap_mb: int = field(default=25, metadata={"help": "bucket size for reduction"})
+     find_unused_parameters: bool = field(
+         default=False, metadata={"help": "disable unused parameter detection when using distributed training"}
+     )
+     broadcast_buffers: bool = field(
+         default=False,
+         metadata={
+             "help": "Copy non-trainable parameters between GPUs, such as " "batchnorm population statistics"
+         },
+     )
+
+
+ @dataclass
+ class DatasetConfig(BaseConfig):
+     data_format: str = field(
+         default="genhpf", metadata={"help": "data format. supported formats: genhpf, meds"}
+     )
+     data: str = field(default=MISSING, metadata={"help": "path to the data directory"})
+     label: bool = field(default=False, metadata={"help": "whether to load labels from the dataset"})
+     vocab_size: int = field(default=MISSING, metadata={"help": "size of the vocabulary"})
+     pad_token_id: int = field(default=0, metadata={"help": "pad token id"})
+     sep_token_id: int = field(default=102, metadata={"help": "sep token id"})
+     dummy_token_id: int = field(default=101, metadata={"help": "dummy token id"})
+     ignore_index: int = field(
+         default=-100,
+         metadata={
+             "help": "specifies a target value that is ignored and does not contribute to "
+             "the input gradient. only applied to cross-entropy loss"
+         },
+     )
+     apply_mask: bool = field(default=False, metadata={"help": "whether to apply masking to the input tokens"})
+     mask_prob: float = field(default=0.15, metadata={"help": "probability for masking tokens"})
+     mask_unit: str = field(
+         default="token", metadata={"help": "unit for masking. supported units: token, event"}
+     )
+     mask_token_id: int = field(default=103, metadata={"help": "mask token id"})
+     num_workers: int = field(default=1, metadata={"help": "how many subprocesses to use for data loading"})
+     batch_size: Optional[int] = field(default=None, metadata={"help": "number of examples in a batch"})
+     train_subset: str = field(
+         default="train", metadata={"help": "data subset name to use for training (e.g., train, valid, test)"}
+     )
+     valid_subset: Optional[str] = field(
+         default="valid", metadata={"help": "comma separated list of data subset names to use for validation"}
+     )
+     test_subset: Optional[str] = field(
+         default="test", metadata={"help": "comma separated list of data subset names to use for test"}
+     )
+     combine_train_subsets: Optional[bool] = field(
+         default=None,
+         metadata={
+             "help": "whether to combine all training subsets into one dataset",
+             "argparse_alias": "--combine-train",
+         },
+     )
+     combine_valid_subsets: Optional[bool] = field(
+         default=None,
+         metadata={
+             "help": "whether to combine all validation subsets into one dataset",
+             "argparse_alias": "--combine-val",
+         },
+     )
+     combine_test_subsets: Optional[bool] = field(
+         default=None,
+         metadata={
+             "help": "whether to combine all test subsets into one dataset",
+             "argparse_alias": "--combine-test",
+         },
+     )
+     disable_validation: Optional[bool] = field(
+         default=None, metadata={"help": "whether to disable validation during training"}
+     )
+
+
+ @dataclass
+ class OptimizationConfig(BaseConfig):
+     max_epoch: int = field(default=0, metadata={"help": "maximum number of epochs to train"})
+     lr: float = field(default=1e-4, metadata={"help": "learning rate"})
+     adam_betas: Any = field(default=(0.9, 0.999), metadata={"help": "betas for Adam optimizer"})
+     adam_eps: float = field(default=1e-8, metadata={"help": "epsilon for Adam optimizer"})
+     weight_decay: float = field(default=1e-8, metadata={"help": "weight decay for Adam optimizer"})
+
+
+ @dataclass
+ class CheckpointConfig(BaseConfig):
+     save_dir: str = field(default="checkpoints", metadata={"help": "path to save checkpoints"})
+     checkpoint_prefix: str = field(
+         default="checkpoint", metadata={"help": "prefix to add to the checkpoint file name"}
+     )
+     checkpoint_suffix: str = field(default="", metadata={"help": "suffix to add to the checkpoint file name"})
+     load_checkpoint: Optional[str] = field(
+         default=None, metadata={"help": "path to a checkpoint to load model weights from, if provided"}
+     )
+     no_save: bool = field(default=False, metadata={"help": "don't save checkpoints"})
+     save_interval: int = field(default=1, metadata={"help": "save a checkpoint every N epochs"})
+     keep_last_epochs: int = field(default=-1, metadata={"help": "keep last N epoch checkpoints"})
+     no_last_checkpoints: bool = field(default=False, metadata={"help": "don't store last checkpoints"})
+     best_checkpoint_metric: str = field(
+         default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
+     )
+     maximize_best_checkpoint_metric: bool = field(
+         default=False,
+         metadata={"help": 'select the largest metric value for saving "best" checkpoints'},
+     )
+     patience: int = field(
+         default=-1,
+         metadata={
+             "help": (
+                 "early stop training if valid performance doesn't "
+                 "improve for N consecutive validation runs; note "
+                 "that this is influenced by --validate-interval"
+             )
+         },
+     )
+
+
+ @dataclass
+ class MEDSConfig(BaseConfig):
+     output_predictions: bool = field(
+         default=False,
+         metadata={
+             "help": "whether to output predictions. if turned on, `genhpf-test` will automatically "
+             "output predictions of the test set specified by `dataset.test_subset` in the "
+             "`meds.output_dir` directory."
+         },
+     )
+     labels_dir: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "a path to the label directory for MEDS dataset. this is required to store "
+             "output predictions in the format expected by `meds-evaluation` when "
+             "`meds.output_predictions` is turned on."
+         },
+     )
+     output_dir: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "a path to the output directory to store output predictions. "
+             "this is only used when `meds.output_predictions` is turned on."
+         },
+     )
+
+
+ @dataclass
+ class Config(BaseConfig):
+     common: CommonConfig = field(default_factory=CommonConfig)
+     distributed_training: DistributedTrainingConfig = field(default_factory=DistributedTrainingConfig)
+     dataset: DatasetConfig = field(default_factory=DatasetConfig)
+     optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
+     checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
+     meds: MEDSConfig = field(default_factory=MEDSConfig)
+     model: Any = MISSING
+     criterion: Any = None
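
Note: taken together, these dataclasses form the root Hydra config tree, with `model` and `criterion` typed as `Any` so that model- and criterion-specific dataclasses can be merged in later (see `initialize.py` below). As a rough, non-authoritative sketch, the tree can be materialized and overridden with plain OmegaConf; the override values below are made up for illustration:

from omegaconf import OmegaConf

from genhpf.configs import Config

# build the structured config; MISSING fields (e.g. dataset.data) stay unset
cfg = OmegaConf.structured(Config)

# apply dotted overrides the way Hydra would from the command line
cfg.merge_with_dotlist(
    [
        "common.seed=1234",
        "dataset.data=/path/to/data",  # hypothetical path
        "dataset.batch_size=32",
        "optimization.lr=5e-5",
    ]
)
print(OmegaConf.to_yaml(cfg.optimization))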
genhpf/configs/constants.py ADDED
@@ -0,0 +1,29 @@
+ from enum import Enum, EnumMeta
+ from typing import List
+
+ class StrEnumMeta(EnumMeta):
+     # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
+     # https://github.com/facebookresearch/hydra/issues/1156
+     @classmethod
+     def __instancecheck__(cls, other):
+         return "enum" in str(type(other))
+
+
+ class StrEnum(Enum, metaclass=StrEnumMeta):
+     def __str__(self):
+         return self.value
+
+     def __eq__(self, other: str):
+         return self.value == other
+
+     def __repr__(self):
+         return self.value
+
+     def __hash__(self):
+         return hash(str(self))
+
+ def ChoiceEnum(choices: List[str]):
+     """return the Enum class used to enforce list of choices"""
+     return StrEnum("Choices", {k: k for k in choices})
+
+ LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm", "csv"])
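
Note: `ChoiceEnum` builds a string-valued enum whose members stringify and compare equal to their names, which is how `CommonConfig.log_format` is validated against a fixed set while still accepting plain strings. A minimal sketch of the intended behavior (the `MASK_UNIT_CHOICES` example is hypothetical):

from genhpf.configs.constants import LOG_FORMAT_CHOICES, ChoiceEnum

assert str(LOG_FORMAT_CHOICES.tqdm) == "tqdm"  # __str__ returns the value
assert LOG_FORMAT_CHOICES.json == "json"       # __eq__ compares to raw strings

# the same factory can enforce any fixed set of choices
MASK_UNIT_CHOICES = ChoiceEnum(["token", "event"])
assert str(MASK_UNIT_CHOICES.event) == "event"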
genhpf/configs/initialize.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ from dataclasses import _MISSING_TYPE
+ from hydra.core.config_store import ConfigStore
+ from omegaconf import DictConfig, OmegaConf
+
+ from genhpf.configs import Config
+
+ logger = logging.getLogger(__name__)
+
+ def hydra_init(cfg_name="config") -> None:
+     cs = ConfigStore.instance()
+     cs.store(name=cfg_name, node=Config)
+
+     for k in Config.__dataclass_fields__:
+         v = Config.__dataclass_fields__[k].default
+         if isinstance(v, _MISSING_TYPE):
+             v = Config.__dataclass_fields__[k].default_factory
+             if not isinstance(v, _MISSING_TYPE):
+                 v = v()
+         try:
+             cs.store(name=k, node=v)
+         except BaseException:
+             logger.error(f"{k} - {v}")
+             raise
+
+ def add_defaults(cfg: DictConfig) -> None:
+     """This function adds default values that are stored in dataclasses that hydra doesn't know about"""
+
+     from genhpf.criterions import CRITERION_DATACLASS_REGISTRY
+     from genhpf.models import MODEL_DATACLASS_REGISTRY
+     from genhpf.configs.utils import merge_with_parent
+     from typing import Any
+
+     OmegaConf.set_struct(cfg, False)
+
+     for k, v in Config.__dataclass_fields__.items():
+         field_cfg = cfg.get(k)
+         if field_cfg is not None and v.type == Any:
+             dc = None
+
+             if isinstance(field_cfg, str):
+                 field_cfg = DictConfig({"_name": field_cfg})
+                 field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]
+
+             name = field_cfg.get("_name")
+
+             if k == "model":
+                 dc = MODEL_DATACLASS_REGISTRY.get(name)
+             elif k == "criterion":
+                 dc = CRITERION_DATACLASS_REGISTRY.get(name)
+
+             if dc is not None:
+                 cfg[k] = merge_with_parent(dc, field_cfg)
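
Note: `hydra_init` must run before the `@hydra.main`-decorated entry point dispatches, so that the root `Config` node (plus one node per top-level group) is in Hydra's ConfigStore; `add_defaults` then back-fills the model/criterion dataclass defaults hidden behind the `Any`-typed fields. A hedged sketch of the expected call order (this is not the package's actual `genhpf/scripts/train.py`, which is listed above but not shown):

import hydra
from omegaconf import DictConfig

from genhpf.configs.initialize import add_defaults, hydra_init

@hydra.main(config_path=None, config_name="config")
def main(cfg: DictConfig) -> None:
    add_defaults(cfg)  # resolve model/criterion dataclasses behind `Any`
    ...  # build datasets, model, criterion, trainer here

if __name__ == "__main__":
    hydra_init("config")  # register configs before hydra.main dispatches
    main()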
genhpf/configs/utils.py ADDED
@@ -0,0 +1,29 @@
+ from dataclasses import is_dataclass
+ from omegaconf import OmegaConf, open_dict
+
+ from genhpf.configs import BaseConfig
+
+ def merge_with_parent(dc: BaseConfig, cfg: BaseConfig, remove_missing=False):
+     if remove_missing:
+
+         def remove_missing_rec(src_keys, target_cfg):
+             if is_dataclass(target_cfg):
+                 target_keys = set(target_cfg.__dataclass_fields__.keys())
+             else:
+                 target_keys = set(target_cfg.keys())
+
+             for k in list(src_keys.keys()):
+                 if k not in target_keys:
+                     del src_keys[k]
+                 elif OmegaConf.is_config(src_keys[k]):
+                     tgt = getattr(target_cfg, k)
+                     if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")):
+                         remove_missing_rec(src_keys[k], tgt)
+
+         with open_dict(cfg):
+             remove_missing_rec(cfg, dc)
+
+     merged_cfg = OmegaConf.merge(dc, cfg)
+     merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
+     OmegaConf.set_struct(merged_cfg, True)
+     return merged_cfg
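
Note: `merge_with_parent` is what lets a sparse, user-supplied `DictConfig` inherit every default from its dataclass, with `set_struct(merged_cfg, True)` rejecting unknown keys afterwards. A minimal usage sketch, assuming the `BinaryCrossEntropyConfig` dataclass from the criterions diff further below:

from omegaconf import DictConfig

from genhpf.configs.utils import merge_with_parent
from genhpf.criterions.binary_cross_entropy import BinaryCrossEntropyConfig

user_cfg = DictConfig({"_name": "binary_cross_entropy", "threshold": 0.7})
merged = merge_with_parent(BinaryCrossEntropyConfig(), user_cfg)
assert merged.threshold == 0.7  # user override wins; other fields keep defaults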
genhpf/criterions/__init__.py ADDED
@@ -0,0 +1,74 @@
+ import importlib
+ import os
+
+ from hydra.core.config_store import ConfigStore
+
+ from genhpf.criterions.criterion import BaseCriterion  # noqa
+ from genhpf.configs.utils import merge_with_parent
+
+ CRITERION_REGISTRY = {}
+ CRITERION_DATACLASS_REGISTRY = {}
+
+ def build_criterion(cfg) -> BaseCriterion:
+     criterion = None
+     criterion_type = getattr(cfg, "_name", None)
+
+     if criterion_type in CRITERION_REGISTRY:
+         criterion = CRITERION_REGISTRY[criterion_type]
+         # set defaults from dataclass
+         dc = CRITERION_DATACLASS_REGISTRY[criterion_type]
+         cfg = merge_with_parent(dc(), cfg)
+
+     assert criterion is not None, (
+         f"Could not infer criterion type from {str(criterion_type)}. "
+         + "Available criterions: "
+         + str(CRITERION_REGISTRY.keys())
+         + " Requested criterion type: "
+         + str(criterion_type)
+     )
+
+     return criterion.build_criterion(cfg)
+
+ def register_criterion(name, dataclass=None):
+     """
+     New criterion types can be added with the :func:`register_criterion`
+     function decorator.
+
+     Args:
+         name (str): the name of the criterion
+     """
+
+     def register_criterion_cls(cls):
+         if name in CRITERION_REGISTRY:
+             raise ValueError(f"Cannot register duplicate criterion ({name})")
+         if not issubclass(cls, BaseCriterion):
+             raise ValueError(
+                 f"Criterion ({name}: {cls.__name__}) must extend BaseCriterion"
+             )
+         CRITERION_REGISTRY[name] = cls
+         if dataclass is not None:
+             CRITERION_DATACLASS_REGISTRY[name] = dataclass
+
+             cs = ConfigStore.instance()
+             node = dataclass()
+             node._name = name
+             cs.store(name=name, group="criterion", node=node, provider="genhpf")
+
+         return cls
+
+     return register_criterion_cls
+
+ def import_criterions(criterions_dir, namespace):
+     for file in os.listdir(criterions_dir):
+         path = os.path.join(criterions_dir, file)
+         if (
+             not file.startswith("_")
+             and not file.startswith(".")
+             and (file.endswith(".py") or os.path.isdir(path))
+         ):
+             criterion_name = file[: file.find(".py")] if file.endswith(".py") else file
+             importlib.import_module(namespace + "." + criterion_name)
+
+ # automatically import any Python files in the criterions/ directory
+ criterions_dir = os.path.dirname(__file__)
+ import_criterions(criterions_dir, "genhpf.criterions")
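
Note: with this registry in place, a new loss plugs in via the decorator and becomes reachable from `build_criterion` through its `_name`. An illustrative registration follows; the config field and class body are assumptions, only the decorator/registry API comes from the file above:

from dataclasses import dataclass, field

from genhpf.criterions import BaseCriterion, register_criterion
from genhpf.criterions.criterion import CriterionConfig

@dataclass
class MyCriterionConfig(CriterionConfig):
    margin: float = field(default=0.1, metadata={"help": "hypothetical margin"})

@register_criterion("my_criterion", dataclass=MyCriterionConfig)
class MyCriterion(BaseCriterion):
    def __init__(self, cfg: MyCriterionConfig):
        super().__init__(cfg)
        self.margin = cfg.margin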
genhpf/criterions/binary_cross_entropy.py ADDED
@@ -0,0 +1,114 @@
+ import math
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ import genhpf.utils.utils as utils
+ from genhpf.criterions import BaseCriterion, register_criterion
+ from genhpf.criterions.criterion import CriterionConfig
+ from genhpf.loggings import meters, metrics
+ from genhpf.loggings.meters import safe_round
+
+
+ @dataclass
+ class BinaryCrossEntropyConfig(CriterionConfig):
+     threshold: float = field(default=0.5, metadata={"help": "threshold value for binary classification"})
+
+
+ @register_criterion("binary_cross_entropy", dataclass=BinaryCrossEntropyConfig)
+ class BinaryCrossEntropy(BaseCriterion):
+     def __init__(self, cfg: BinaryCrossEntropyConfig):
+         super().__init__(cfg)
+
+         if self.task_names is not None and len(self.task_names) > 1:
+             raise ValueError(
+                 "binary_cross_entropy only supports single task training. if you want "
+                 "to train multiple tasks, use multi_task_criterion instead."
+             )
+
+         self.threshold = cfg.threshold
+
+     def compute_loss(
+         self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
+     ) -> Tuple[torch.Tensor, List[float]]:
+         logits = logits.flatten()
+         targets = targets.float()
+         probs = torch.sigmoid(logits)
+         loss = F.binary_cross_entropy(input=probs, target=targets, reduction="sum")
+         return loss, [loss.detach().item()]
+
+     def get_sample_size(self, sample, targets: torch.Tensor) -> int:
+         if "sample_size" in sample:
+             sample_size = sample["sample_size"]
+         else:
+             sample_size = targets.numel()
+         return sample_size
+
+     def get_logging_outputs(
+         self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
+     ) -> List[Dict[str, Any]]:
+         with torch.no_grad():
+             probs = torch.sigmoid(logits.flatten())
+             outputs = probs > self.threshold
+
+             if probs.numel() == 0:
+                 corr = 0
+                 count = 0
+             else:
+                 count = float(probs.numel())
+                 corr = (outputs == targets).sum().item()
+
+             logging_output["correct"] = corr
+             logging_output["count"] = count
+
+             # report aucs only in eval mode
+             if not self.training:
+                 logging_output["_y_true"] = targets.cpu().numpy()
+                 logging_output["_y_score"] = probs.cpu().numpy()
+
+         return logging_output
+
+     @staticmethod
+     def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
+         """Aggregate logging outputs from data parallel training."""
+         if prefix is None:
+             prefix = ""
+         elif prefix is not None and not prefix.endswith("_"):
+             prefix = prefix + "_"
+
+         loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+
+         sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))
+
+         metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)
+
+         if "_y_true" in logging_outputs[0] and "_y_score" in logging_outputs[0]:
+             y_true = np.concatenate([log.get("_y_true", []) for log in logging_outputs])
+             y_score = np.concatenate([log.get("_y_score", []) for log in logging_outputs])
+
+             metrics.log_custom(meters.AUCMeter, f"_{prefix}auc", y_score, y_true)
+
+         correct = sum(log.get("correct", 0) for log in logging_outputs)
+         metrics.log_scalar(f"_{prefix}correct", correct)
+
+         total = sum(log.get("count", 0) for log in logging_outputs)
+         metrics.log_scalar(f"_{prefix}total", total)
+
+         if total > 0:
+             metrics.log_derived(
+                 f"{prefix}accuracy",
+                 lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
+                 if meters[f"_{prefix}total"].sum > 0
+                 else float("nan"),
+             )
+
+     def post_validate(self, stats, agg, **kwargs):
+         for key in agg.keys():
+             if key.startswith("_") and key.endswith("auc"):
+                 stats[key[1:-3] + "auroc"] = agg[key].auroc
+                 stats[key[1:-3] + "auprc"] = agg[key].auprc
+
+         return stats
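
Note: `compute_loss` applies an explicit sigmoid followed by a sum-reduced `F.binary_cross_entropy`, and `reduce_metrics` divides the summed loss by the sample size and by ln 2, i.e. it reports loss in bits per example. A quick numeric sanity check of that path (values are made up; the logits-based variant shown for comparison is presumably what the sibling `binary_cross_entropy_with_logits.py` uses, though that file's body is not shown here):

import math

import torch
import torch.nn.functional as F

logits = torch.tensor([0.2, -1.3, 2.1])
targets = torch.tensor([1.0, 0.0, 1.0])

probs = torch.sigmoid(logits)
loss = F.binary_cross_entropy(input=probs, target=targets, reduction="sum")

# mathematically equivalent, and numerically safer for extreme logits
loss_logits = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum")
assert torch.allclose(loss, loss_logits, atol=1e-6)

print(loss.item() / targets.numel() / math.log(2))  # bits per example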