genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
import genhpf.utils.utils as utils
|
|
10
|
+
from genhpf.criterions import BaseCriterion, register_criterion
|
|
11
|
+
from genhpf.criterions.criterion import CriterionConfig
|
|
12
|
+
from genhpf.loggings import meters, metrics
|
|
13
|
+
from genhpf.loggings.meters import safe_round
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class BinaryCrossEntropyWithLogitsConfig(CriterionConfig):
    """Configuration for the ``binary_cross_entropy_with_logits`` criterion."""

    # Decision threshold applied to predicted probabilities when counting
    # correct predictions for the accuracy metric.
    threshold: float = field(
        default=0.5,
        metadata={"help": "threshold value for binary classification"},
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_criterion("binary_cross_entropy_with_logits", dataclass=BinaryCrossEntropyWithLogitsConfig)
class BinaryCrossEntropyWithLogits(BaseCriterion):
    """Binary cross-entropy on raw logits for single-task training.

    The loss is summed over all elements. Accuracy counters are always
    logged; outside of training mode the raw targets and probabilities are
    logged as well so that AUROC/AUPRC can be reported after validation.
    """

    def __init__(self, cfg: BinaryCrossEntropyWithLogitsConfig):
        """Store the decision threshold and reject multi-task configurations.

        Raises:
            ValueError: if more than one task name is configured.
        """
        super().__init__(cfg)

        if self.task_names is not None and len(self.task_names) > 1:
            raise ValueError(
                "binary_cross_entropy_with_logits only supports single task training."
                " if you want to train multiple tasks, use multi_task_criterion instead."
            )

        self.threshold = cfg.threshold

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        """Return the summed BCE-with-logits loss and its scalar value for logging.

        ``logits`` and ``targets`` must have identical sizes; ``targets`` is
        cast to float before the loss is evaluated. ``sample``, ``net_output``
        and ``model`` are accepted for interface compatibility and unused.
        """
        assert (
            logits.size() == targets.size()
        ), f"logits and targets must have the same size: {logits.size()} vs {targets.size()}"
        targets = targets.float()
        loss = F.binary_cross_entropy_with_logits(input=logits, target=targets, reduction="sum")
        return loss, [loss.detach().item()]

    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
        """Return ``sample["sample_size"]`` if present, else the number of target elements."""
        if "sample_size" in sample:
            return sample["sample_size"]
        return targets.numel()

    def get_logging_outputs(
        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
    ) -> Dict[str, Any]:
        """Add accuracy counters (and, in eval mode, AUC inputs) to ``logging_output``.

        The same dictionary that was passed in is updated and returned.
        """
        with torch.no_grad():
            probs = torch.sigmoid(logits)
            outputs = probs > self.threshold

            if probs.numel() == 0:
                corr = 0
                count = 0
            else:
                count = float(probs.numel())
                corr = (outputs == targets).sum().item()

            logging_output["correct"] = corr
            logging_output["count"] = count

            # report aucs only in eval mode
            if not self.training:
                logging_output["_y_true"] = targets.cpu().numpy()
                logging_output["_y_score"] = probs.cpu().numpy()

        return logging_output

    @staticmethod
    def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
        """Aggregate logging outputs from data parallel training.

        ``prefix`` (with a trailing underscore appended when missing) is
        prepended to every metric name. An empty ``logging_outputs`` list is
        tolerated.
        """
        if prefix is None:
            prefix = ""
        elif not prefix.endswith("_"):
            prefix = prefix + "_"

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))

        # Dividing by ln(2) converts the natural-log loss to base 2.
        metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)

        # Use every entry that carries AUC inputs instead of inspecting only
        # the first one, which fails on an empty list and silently drops the
        # AUC when the first entry happens to lack the arrays.
        scored_logs = [log for log in logging_outputs if "_y_true" in log and "_y_score" in log]
        if scored_logs:
            y_true = np.concatenate([log["_y_true"] for log in scored_logs])
            y_score = np.concatenate([log["_y_score"] for log in scored_logs])

            metrics.log_custom(meters.AUCMeter, f"_{prefix}auc", y_score, y_true)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}total", total)

        if total > 0:
            metrics.log_derived(
                f"{prefix}accuracy",
                lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
                if meters[f"_{prefix}total"].sum > 0
                else float("nan"),
            )

    def post_validate(self, stats, agg, **kwargs):
        """Copy ``auroc``/``auprc`` from aggregated ``_*auc`` meters into ``stats``.

        For each key in ``agg`` that starts with ``_`` and ends with ``auc``,
        the leading underscore and the ``auc`` suffix are stripped to form the
        ``...auroc`` and ``...auprc`` stat names.
        """
        for key in agg.keys():
            if key.startswith("_") and key.endswith("auc"):
                stats[key[1:-3] + "auroc"] = agg[key].auroc
                stats[key[1:-3] + "auprc"] = agg[key].auprc

        return stats
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.nn.modules.loss import _Loss
|
|
6
|
+
|
|
7
|
+
from genhpf.configs import BaseConfig
|
|
8
|
+
from genhpf.models.genhpf import GenHPF
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class CriterionConfig(BaseConfig):
    """Options shared by all criterion configurations."""

    # Names of the tasks the criterion is applied to (optional).
    task_names: Optional[List[str]] = field(
        default=None,
        metadata={"help": "a list of task names for multi-task learning"},
    )
    # Number of labels per task (optional).
    num_labels: Optional[List[int]] = field(
        default=None,
        metadata={"help": "a list of number of labels for each task"},
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseCriterion(_Loss):
    """Base class for criterions.

    Subclasses implement ``compute_loss``, ``get_sample_size``,
    ``get_logging_outputs`` and ``reduce_metrics``; ``forward`` ties the first
    three together for a single sample.
    """

    def __init__(self, cfg: CriterionConfig):
        """Keep the config and expose its ``task_names`` / ``num_labels``."""
        super().__init__()
        self.cfg = cfg

        self.task_names, self.num_labels = cfg.task_names, cfg.num_labels

    @classmethod
    def build_criterion(cls, cfg: CriterionConfig):
        """Construct a new criterion instance."""
        return cls(cfg)

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        """Compute the loss given the logits and targets from the model."""
        raise NotImplementedError("Criterion must implement the `compute_loss` method")

    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
        """Get the sample size, which is used as the denominator for the gradient."""
        raise NotImplementedError("Criterion must implement the `get_sample_size` method")

    def get_logging_outputs(
        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
    ) -> List[Dict[str, Any]]:
        """Get the logging output to display while training."""
        raise NotImplementedError("Criterion must implement the `get_logging_outputs` method")

    def forward(self, model: GenHPF, sample, return_net_output=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1. the loss
        2. the sample size, which is used as the denominator for the gradient
        3. logging outputs to display while training

        When ``return_net_output`` is true, the model's output is appended as
        a fourth element.
        """
        net_output = model(**sample["net_input"])
        logits = model.get_logits(sample, net_output)
        targets = model.get_targets(sample, net_output)

        loss, losses_to_log = self.compute_loss(
            logits, targets, sample=sample, net_output=net_output, model=model
        )
        sample_size = self.get_sample_size(sample, targets)

        if len(losses_to_log) > 1:
            # Several loss terms: log the total plus every individual term.
            logging_output = {"loss": loss.item()}
            for idx, value in enumerate(losses_to_log):
                logging_output[f"loss_{idx}"] = value
        else:
            logging_output = {"loss": losses_to_log[0]}
        logging_output["sample_size"] = sample_size
        logging_output = self.get_logging_outputs(logging_output, logits, targets, sample)

        result = (loss, sample_size, logging_output)
        if return_net_output:
            result = result + (net_output,)
        return result

    @staticmethod
    def reduce_metrics(stats: Dict[str, Any], prefix: str = None) -> None:
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from omegaconf import II
|
|
9
|
+
|
|
10
|
+
import genhpf.utils.utils as utils
|
|
11
|
+
from genhpf.criterions import BaseCriterion, register_criterion
|
|
12
|
+
from genhpf.criterions.criterion import CriterionConfig
|
|
13
|
+
from genhpf.loggings import meters, metrics
|
|
14
|
+
from genhpf.loggings.meters import safe_round
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class CrossEntropyConfig(CriterionConfig):
    """Configuration for the ``cross_entropy`` criterion."""

    # Opt-in switch for collecting AUC inputs during evaluation.
    report_auc: bool = field(
        default=False,
        metadata={
            "help": (
                "whether to report auc. note that this is only available in eval mode and "
                "can cause memory and performance issues if enabled."
            )
        },
    )
    # Interpolated from the dataset config key `dataset.ignore_index`.
    ignore_index: int = II("dataset.ignore_index")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@register_criterion("cross_entropy", dataclass=CrossEntropyConfig)
class CrossEntropy(BaseCriterion):
    """Cross-entropy criterion for single-task training.

    Targets equal to ``ignore_index`` are excluded from the loss and from the
    accuracy counters. AUC inputs are collected only when ``report_auc`` is
    enabled and the criterion is not in training mode.
    """

    def __init__(self, cfg: CrossEntropyConfig):
        """Store ``report_auc`` / ``ignore_index`` and reject multi-task configs.

        Raises:
            ValueError: if more than one task name is configured.
        """
        super().__init__(cfg)

        if self.task_names is not None and len(self.task_names) > 1:
            raise ValueError(
                "cross_entropy only supports single task training. if you want to train multiple"
                " tasks, use multi_task_criterion instead."
            )

        self.report_auc = cfg.report_auc
        self.ignore_index = cfg.ignore_index

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        """Compute the summed cross-entropy loss given the logits and targets.

        ``logits`` is flattened to ``(-1, logits.size(-1))`` and ``targets``
        to a 1-D long tensor. When every target is ignored the loss is zero.
        """
        logits = logits.view(-1, logits.size(-1))
        targets = targets.view(-1).long()

        if torch.all(targets == self.ignore_index):
            # Keep the zero attached to the autograd graph so that a backward
            # pass on it yields zero gradients instead of failing on a tensor
            # without grad_fn.
            return logits.sum() * 0.0, [0.0]

        loss = F.cross_entropy(logits, targets, reduction="sum", ignore_index=self.ignore_index)

        return loss, [loss.detach().item()]

    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
        """Return ``sample["sample_size"]`` if present, else the number of target elements."""
        if "sample_size" in sample:
            return sample["sample_size"]
        return targets.numel()

    def get_logging_outputs(
        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
    ) -> Dict[str, Any]:
        """Add accuracy counters (and optional AUC inputs) to ``logging_output``.

        The dictionary passed in is updated and returned, so entries the
        caller stored beforehand are preserved even when no target is valid.
        """
        with torch.no_grad():
            logits = logits.view(-1, logits.size(-1))
            targets = targets.view(-1).long()

            valid = targets != self.ignore_index
            if not valid.any():
                # Nothing to score: record zero counts but keep whatever the
                # caller already put into the dictionary.
                logging_output["correct"] = 0
                logging_output["count"] = 0
                return logging_output

            logits = logits[valid]
            targets = targets[valid]

            preds = logits.argmax(dim=-1)
            count = targets.numel()
            corr = (preds == targets).sum().item()

            logging_output["correct"] = corr
            logging_output["count"] = count

            # report aucs only in eval mode
            if self.report_auc and not self.training:
                # Per-class sigmoid scores are flattened against one-hot
                # targets, i.e. every (sample, class) pair is one binary case.
                probs = torch.sigmoid(logits).view(-1)
                targets = F.one_hot(targets, logits.size(-1)).float().view(-1)

                logging_output["_y_true"] = targets.cpu().numpy()
                logging_output["_y_score"] = probs.cpu().numpy()

        return logging_output

    @staticmethod
    def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
        """Aggregate logging outputs from data parallel training.

        ``prefix`` (with a trailing underscore appended when missing) is
        prepended to every metric name. An empty ``logging_outputs`` list is
        tolerated.
        """
        if prefix is None:
            prefix = ""
        elif not prefix.endswith("_"):
            prefix = prefix + "_"

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))

        # Dividing by ln(2) converts the natural-log loss to base 2.
        metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)

        # Use every entry that carries AUC inputs; entries for batches with
        # no valid target do not have them.
        scored_logs = [log for log in logging_outputs if "_y_true" in log and "_y_score" in log]
        if scored_logs:
            y_true = np.concatenate([log["_y_true"] for log in scored_logs])
            y_score = np.concatenate([log["_y_score"] for log in scored_logs])

            metrics.log_custom(meters.AUCMeter, f"_{prefix}auc", y_score, y_true)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}total", total)

        if total > 0:
            metrics.log_derived(
                f"{prefix}accuracy",
                lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
                if meters[f"_{prefix}total"].sum > 0
                else float("nan"),
            )

    def post_validate(self, stats, agg, **kwargs):
        """Copy ``auroc``/``auprc`` from aggregated ``_*auc`` meters into ``stats``.

        For each key in ``agg`` that starts with ``_`` and ends with ``auc``,
        the leading underscore and the ``auc`` suffix are stripped to form the
        ``...auroc`` and ``...auprc`` stat names.
        """
        for key in agg.keys():
            if key.startswith("_") and key.endswith("auc"):
                stats[key[1:-3] + "auroc"] = agg[key].auroc
                stats[key[1:-3] + "auprc"] = agg[key].auprc

        return stats
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import re
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
import genhpf.utils.utils as utils
|
|
8
|
+
from genhpf.criterions import BaseCriterion, register_criterion
|
|
9
|
+
from genhpf.criterions.criterion import CriterionConfig
|
|
10
|
+
from genhpf.loggings import metrics
|
|
11
|
+
from genhpf.models.genhpf import GenHPF
|
|
12
|
+
|
|
13
|
+
from . import build_criterion
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class MultiTaskCriterionConfig(CriterionConfig):
    """Configuration for the ``multi_task_criterion`` criterion."""

    # Optional per-task multipliers applied to each task's loss.
    task_loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for each loss term. if given, has to be a float list of size n_criterions"},
    )
    # Container whose attributes, one per task name, hold each task's
    # criterion configuration.
    args: Any = field(
        default=None,
        metadata={
            "help": (
                "configurations for each criterion where the name of each argument should "
                "match with the corresponding task name."
            )
        },
    )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@register_criterion("multi_task_criterion", dataclass=MultiTaskCriterionConfig)
class MultiTaskCriterion(BaseCriterion):
    """Combine one criterion per task into a single weighted loss.

    Each task's loss is divided by its own sample size before being summed,
    and the combined sample size is reported as 1. Logging entries are stored
    under ``<task>_...`` keys (``_<task>_...`` for private entries) so that
    ``reduce_metrics`` can route them back to the criterion class of the task.
    """

    def __init__(self, cfg: MultiTaskCriterionConfig):
        """Build one sub-criterion per task name from ``cfg.args.<task_name>``.

        Raises:
            ValueError: if no task names are configured, or fewer loss weights
                than tasks are given.
        """
        super().__init__(cfg)

        if self.task_names is None:
            raise ValueError("multi_task_criterion requires `task_names` to be set.")

        criterions = {}
        for task_name in self.task_names:
            criterion_cfg = getattr(cfg.args, task_name)
            criterions[task_name] = build_criterion(criterion_cfg)
        # NOTE: kept in a plain dict, so the sub-criterions are not registered
        # as submodules; train/eval mode is propagated explicitly below.
        self.criterions = criterions

        if cfg.task_loss_weights is None:
            self.task_loss_weights = [1.0] * len(self.task_names)
        else:
            # Fail early instead of hitting an IndexError in `forward`.
            if len(cfg.task_loss_weights) < len(self.task_names):
                raise ValueError(
                    f"got {len(cfg.task_loss_weights)} task_loss_weights for "
                    f"{len(self.task_names)} tasks"
                )
            self.task_loss_weights = cfg.task_loss_weights

    def forward(self, model: GenHPF, sample, return_net_output=False):
        """Run the model once and accumulate the weighted loss of every task.

        Returns ``(loss, 1, logging_outputs)``, with the model output appended
        when ``return_net_output`` is true.

        Raises:
            ValueError: if the number of logits or targets differs from the
                number of tasks.
        """
        net_output = model(**sample["net_input"])
        logits = model.get_logits(sample, net_output)
        targets = model.get_targets(sample, net_output)

        # Non-dict outputs are attributed to the first task.
        if not isinstance(logits, dict):
            logits = {self.task_names[0]: logits}
        if not isinstance(targets, dict):
            targets = {self.task_names[0]: targets}

        if len(logits) != len(self.task_names) or len(targets) != len(self.task_names):
            raise ValueError(
                "number of logits and targets should be equal to the number of tasks. "
                f"got {len(logits)} logits and {len(targets)} targets for "
                f"{len(self.task_names)} tasks"
            )

        loss = 0.0
        logging_outputs = dict()
        for i, task_name in enumerate(self.task_names):
            criterion = self.criterions[task_name]
            assert (
                task_name in logits and task_name in targets
            ), f"task name {task_name} not found in logits or targets"
            task_logits = logits[task_name]
            task_targets = targets[task_name]
            task_loss, task_losses_to_log = criterion.compute_loss(
                logits=task_logits, targets=task_targets, sample=sample, net_output=net_output, model=model
            )
            task_loss = task_loss * self.task_loss_weights[i]
            sample_size = criterion.get_sample_size(sample, task_targets)

            logging_outputs[f"<{task_name}>_criterion_cls"] = criterion.__class__
            if len(task_losses_to_log) > 1:
                # The key must carry the `<task>` marker like every other
                # entry; `reduce_metrics` groups keys by that marker.
                logging_outputs[f"<{task_name}>_loss"] = task_loss.item()
                for j, l in enumerate(task_losses_to_log):
                    logging_outputs[f"<{task_name}>_loss_{j}"] = l
            else:
                logging_outputs[f"<{task_name}>_loss"] = task_losses_to_log[0]
            logging_outputs[f"<{task_name}>_sample_size"] = sample_size

            task_logging_output = criterion.get_logging_outputs({}, task_logits, task_targets, sample)
            for log, value in task_logging_output.items():
                if log.startswith("_"):
                    # Keep the leading underscore in front of the task marker.
                    logging_outputs[f"_<{task_name}>_{log[1:]}"] = value
                else:
                    logging_outputs[f"<{task_name}>_{log}"] = value

            # divide task loss by the sample size beforehand to handle different sample
            # sizes for multiple criterions
            loss += task_loss / sample_size

        # manipulate sample_size to be 1 to avoid double-dividing gradients in optimizer later
        sample_size = 1

        if return_net_output:
            return loss, sample_size, logging_outputs, net_output
        return loss, sample_size, logging_outputs

    @staticmethod
    def reduce_metrics(logging_outputs: List[Dict[str, Any]]) -> None:
        """Split the logs per task, delegate to each task's criterion class, and log the total loss.

        The set of keys is taken from the first logging output. Keys without a
        ``<task>`` marker are skipped. An empty list logs nothing.
        """
        if not logging_outputs:
            return

        grouped_log_keys = defaultdict(list)
        for full_key in logging_outputs[0].keys():
            match = re.search(r"\<.*\>", full_key)
            if match is None:
                continue
            # Skip the marker and the underscore that follows it.
            key = full_key[match.end() + 1 :]
            if full_key.startswith("_"):
                key = "_" + key
            grouped_log_keys[match.group()[1:-1]].append(key)

        total_loss = 0
        for group, keys in grouped_log_keys.items():
            criterion_cls = logging_outputs[0][f"<{group}>_criterion_cls"]
            wanted_keys = set(keys) - {"criterion_cls"}

            task_logs = []
            for log in logging_outputs:
                task_log = {}
                for log_key in wanted_keys:
                    if log_key.startswith("_") and f"_<{group}>{log_key}" in log:
                        task_log[log_key] = log[f"_<{group}>{log_key}"]
                    elif f"<{group}>_{log_key}" in log:
                        task_log[log_key] = log[f"<{group}>_{log_key}"]
                task_logs.append(task_log)
            criterion_cls.reduce_metrics(task_logs, prefix=group)

            loss_sum = utils.item(sum(log.get("loss", 0) for log in task_logs))
            sample_size = utils.item(sum(log.get("sample_size", 0) for log in task_logs))

            total_loss += loss_sum / (sample_size or 1) / math.log(2)

        metrics.log_scalar("loss", total_loss, 1, round=3)

    def post_validate(self, stats, agg, **kwargs):
        """Run each task's ``post_validate`` and add ``avg_*`` stats across tasks.

        Meters in ``agg`` are assigned to the first task whose name prefixes
        the key (optionally after one leading character). Afterwards, every
        stat whose key starts with a task name contributes to the mean stored
        under ``avg_<stat>``.
        """
        meters_by_task = {}
        for key in agg:
            for task_name in self.task_names:
                if key.startswith(task_name) or key[1:].startswith(task_name):
                    meters_by_task.setdefault(task_name, {})[key] = agg[key]
                    break

        for task_name, task_meters in meters_by_task.items():
            criterion = self.criterions[task_name]
            if hasattr(criterion, "post_validate"):
                stats = criterion.post_validate(stats, task_meters, **kwargs)

        # Collect values in a separate mapping so that lists are never stored
        # in (or confused with existing entries of) `stats`.
        values_by_stat = defaultdict(list)
        for key in list(stats.keys()):
            for task_name in self.task_names:
                if key.startswith(task_name):
                    values_by_stat[key[len(task_name) + 1 :]].append(stats[key])
                    break

        for stat_key, values in values_by_stat.items():
            stats[f"avg_{stat_key}"] = sum(values) / len(values)
        return stats

    def train(self, mode: bool = True):
        """Set training mode on this criterion and on every sub-criterion.

        The sub-criterions live in a plain dict and therefore do not follow
        the mode of this module automatically.
        """
        super().train(mode)
        for criterion in getattr(self, "criterions", {}).values():
            criterion.train(mode)
        return self

    def eval(self):
        """Switch this criterion and every sub-criterion to evaluation mode."""
        return self.train(False)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
import genhpf.utils.utils as utils
|
|
9
|
+
from genhpf.criterions import BaseCriterion, register_criterion
|
|
10
|
+
from genhpf.criterions.criterion import CriterionConfig
|
|
11
|
+
from genhpf.loggings import metrics
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class SimCLRCriterionConfig(CriterionConfig):
    """Configuration for the ``simclr_criterion`` criterion."""

    # Similarity logits are divided by this value before the cross-entropy.
    temp: float = field(
        default=0.1,
        metadata={"help": "temperature to divide logits by"},
    )
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@register_criterion("simclr_criterion", dataclass=SimCLRCriterionConfig)
class SimCLRCriterion(BaseCriterion):
    """Contrastive criterion over pairs of views.

    Rows ``2k`` and ``2k + 1`` of the input are treated as a positive pair;
    all remaining rows serve as negatives for both of them.
    """

    def __init__(self, cfg: SimCLRCriterionConfig):
        """Store the temperature taken from the config."""
        super().__init__(cfg)

        self.temp = cfg.temp

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor = None, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        """Compute the summed contrastive cross-entropy loss.

        ``targets``, ``sample``, ``net_output`` and ``model`` are unused.
        Returns the loss tensor and a one-element list with its scalar value.
        """
        device = logits.device
        embeddings = F.normalize(logits, dim=1)  # unit-length rows

        bsz = int(embeddings.shape[0] / 2)
        num_views = bsz * 2

        # True everywhere except on the diagonal (self-similarity).
        off_diagonal = ~torch.eye(num_views, dtype=torch.bool, device=device)

        # Column of each row's partner once the diagonal entry is removed:
        # both rows of pair k find their positive at column 2k.
        rows = torch.arange(num_views, device=device)
        cols = 2 * torch.arange(bsz, dtype=torch.long, device=device).repeat_interleave(2)

        is_negative = torch.ones((num_views, num_views - 1), dtype=torch.bool, device=device)
        is_negative[rows, cols] = False

        # Dot products of unit vectors, i.e. cosine similarities.
        similarity = torch.matmul(embeddings, embeddings.T)

        # Eliminate similarity between same view
        similarity = torch.masked_select(similarity, off_diagonal).view(similarity.size(0), -1)

        positives = similarity[rows, cols].unsqueeze(1)
        negatives = torch.masked_select(similarity, is_negative).view(similarity.size(0), -1)

        # The positive sits in column 0, so the target class is always 0.
        contrastive_logits = torch.cat((positives, negatives), dim=1) / self.temp
        target = torch.zeros((contrastive_logits.size(0),), dtype=torch.long, device=device)

        loss = F.cross_entropy(contrastive_logits, target, reduction="sum")

        return loss, [loss.detach().item()]

    def get_sample_size(self, sample, targets: torch.Tensor = None) -> int:
        """Return the size of the first dimension of ``sample["net_input"]["input_ids"]``."""
        input_ids = sample["net_input"]["input_ids"]
        return input_ids.size(0)

    def get_logging_outputs(self, logging_output, logits, target, sample=None, net_output=None):
        """Return ``logging_output`` unchanged; no extra entries are logged."""
        return logging_output

    @staticmethod
    def reduce_metrics(logging_outputs, prefix: str = None) -> None:
        """Aggregate logging outputs from data parallel training."""
        if prefix is None:
            prefix = ""
        elif not prefix.endswith("_"):
            prefix = prefix + "_"

        total_loss = utils.item(sum(entry.get("loss", 0) for entry in logging_outputs))
        total_size = utils.item(sum(entry.get("sample_size", 0) for entry in logging_outputs))

        # Dividing by ln(2) converts the natural-log loss to base 2.
        metrics.log_scalar(f"{prefix}loss", total_loss / (total_size or 1) / math.log(2), total_size, round=3)
|