genhpf-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of genhpf has been flagged as possibly problematic (see the registry page for details).

Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +233 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +174 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +584 -0
  53. genhpf/scripts/test.py +261 -0
  54. genhpf/scripts/train.py +350 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.0.dist-info/LICENSE +21 -0
  63. genhpf-1.0.0.dist-info/METADATA +197 -0
  64. genhpf-1.0.0.dist-info/RECORD +67 -0
  65. genhpf-1.0.0.dist-info/WHEEL +5 -0
  66. genhpf-1.0.0.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.0.dist-info/top_level.txt +1 -0
genhpf/criterions/binary_cross_entropy_with_logits.py
@@ -0,0 +1,115 @@
+import math
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+import genhpf.utils.utils as utils
+from genhpf.criterions import BaseCriterion, register_criterion
+from genhpf.criterions.criterion import CriterionConfig
+from genhpf.loggings import meters, metrics
+from genhpf.loggings.meters import safe_round
+
+
+@dataclass
+class BinaryCrossEntropyWithLogitsConfig(CriterionConfig):
+    threshold: float = field(default=0.5, metadata={"help": "threshold value for binary classification"})
+
+
+@register_criterion("binary_cross_entropy_with_logits", dataclass=BinaryCrossEntropyWithLogitsConfig)
+class BinaryCrossEntropyWithLogits(BaseCriterion):
+    def __init__(self, cfg: BinaryCrossEntropyWithLogitsConfig):
+        super().__init__(cfg)
+
+        if self.task_names is not None and len(self.task_names) > 1:
+            raise ValueError(
+                "binary_cross_entropy_with_logits only supports single task training."
+                " if you want to train multiple tasks, use multi_task_criterion instead."
+            )
+
+        self.threshold = cfg.threshold
+
+    def compute_loss(
+        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
+    ) -> Tuple[torch.Tensor, List[float]]:
+        assert (
+            logits.size() == targets.size()
+        ), f"logits and targets must have the same size: {logits.size()} vs {targets.size()}"
+        targets = targets.float()
+        loss = F.binary_cross_entropy_with_logits(input=logits, target=targets, reduction="sum")
+        return loss, [loss.detach().item()]
+
+    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
+        if "sample_size" in sample:
+            sample_size = sample["sample_size"]
+        else:
+            sample_size = targets.numel()
+        return sample_size
+
+    def get_logging_outputs(
+        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
+    ) -> List[Dict[str, Any]]:
+        with torch.no_grad():
+            probs = torch.sigmoid(logits)
+            outputs = probs > self.threshold
+
+            if probs.numel() == 0:
+                corr = 0
+                count = 0
+            else:
+                count = float(probs.numel())
+                corr = (outputs == targets).sum().item()
+
+        logging_output["correct"] = corr
+        logging_output["count"] = count
+
+        # report aucs only in eval mode
+        if not self.training:
+            logging_output["_y_true"] = targets.cpu().numpy()
+            logging_output["_y_score"] = probs.cpu().numpy()
+
+        return logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        if prefix is None:
+            prefix = ""
+        elif prefix is not None and not prefix.endswith("_"):
+            prefix = prefix + "_"
+
+        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+
+        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))
+
+        metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)
+
+        if "_y_true" in logging_outputs[0] and "_y_score" in logging_outputs[0]:
+            y_true = np.concatenate([log.get("_y_true", []) for log in logging_outputs])
+            y_score = np.concatenate([log.get("_y_score", []) for log in logging_outputs])
+
+            metrics.log_custom(meters.AUCMeter, f"_{prefix}auc", y_score, y_true)
+
+        correct = sum(log.get("correct", 0) for log in logging_outputs)
+        metrics.log_scalar(f"_{prefix}correct", correct)
+
+        total = sum(log.get("count", 0) for log in logging_outputs)
+        metrics.log_scalar(f"_{prefix}total", total)
+
+        if total > 0:
+            metrics.log_derived(
+                f"{prefix}accuracy",
+                lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
+                if meters[f"_{prefix}total"].sum > 0
+                else float("nan"),
+            )
+
+    def post_validate(self, stats, agg, **kwargs):
+        for key in agg.keys():
+            if key.startswith("_") and key.endswith("auc"):
+                stats[key[1:-3] + "auroc"] = agg[key].auroc
+                stats[key[1:-3] + "auprc"] = agg[key].auprc
+
+        return stats
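For orientation, the criterion above reduces to a sum-reduced BCE-with-logits plus a thresholded accuracy count. The standalone sketch below reproduces that arithmetic on made-up tensors; it bypasses the genhpf config/registry plumbing, and the 0.5 threshold simply mirrors the config default.

# Illustrative only: toy reproduction of compute_loss / get_logging_outputs above.
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0], [-1.0], [0.5]])   # raw scores for 3 samples (made up)
targets = torch.tensor([[1.0], [0.0], [0.0]])   # binary labels, same shape as logits

# compute_loss: per-element BCE with logits, summed (not averaged) over the batch
loss = F.binary_cross_entropy_with_logits(input=logits, target=targets, reduction="sum")

# get_logging_outputs: sigmoid probabilities thresholded at cfg.threshold (default 0.5)
probs = torch.sigmoid(logits)
preds = probs > 0.5
correct = (preds == targets.bool()).sum().item()
print(f"loss={loss.item():.4f}, correct={correct}/{targets.numel()}")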
genhpf/criterions/criterion.py
@@ -0,0 +1,87 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+from genhpf.configs import BaseConfig
+from genhpf.models.genhpf import GenHPF
+
+
+@dataclass
+class CriterionConfig(BaseConfig):
+    task_names: Optional[List[str]] = field(
+        default=None, metadata={"help": "a list of task names for multi-task learning"}
+    )
+    num_labels: Optional[List[int]] = field(
+        default=None, metadata={"help": "a list of number of labels for each task"}
+    )
+
+
+class BaseCriterion(_Loss):
+    def __init__(self, cfg: CriterionConfig):
+        super().__init__()
+        self.cfg = cfg
+
+        self.task_names = cfg.task_names
+        self.num_labels = cfg.num_labels
+
+    @classmethod
+    def build_criterion(cls, cfg: CriterionConfig):
+        """Construct a new criterion instance."""
+        return cls(cfg)
+
+    def compute_loss(
+        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
+    ) -> Tuple[torch.Tensor, List[float]]:
+        """Compute the loss given the logits and targets from the model."""
+        raise NotImplementedError("Criterion must implement the `compute_loss` method")
+
+    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
+        """Get the sample size, which is used as the denominator for the gradient."""
+        raise NotImplementedError("Criterion must implement the `get_sample_size` method")
+
+    def get_logging_outputs(
+        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
+    ) -> List[Dict[str, Any]]:
+        """
+        Get the logging output to display while training
+        """
+        raise NotImplementedError("Criterion must implement the `get_logging_outputs` method")
+
+    def forward(self, model: GenHPF, sample, return_net_output=False):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1. the loss
+        2. the sample size, which is used as the denominator for the gradient
+        3. logging outputs to display while training
+        """
+        net_output = model(**sample["net_input"])
+        logits = model.get_logits(sample, net_output)
+        targets = model.get_targets(sample, net_output)
+
+        loss, losses_to_log = self.compute_loss(
+            logits, targets, sample=sample, net_output=net_output, model=model
+        )
+        sample_size = self.get_sample_size(sample, targets)
+
+        logging_output = {}
+        if len(losses_to_log) > 1:
+            logging_output["loss"] = loss.item()
+            for i, l in enumerate(losses_to_log):
+                logging_output[f"loss_{i}"] = l
+        else:
+            logging_output["loss"] = losses_to_log[0]
+        logging_output["sample_size"] = sample_size
+        logging_output = self.get_logging_outputs(logging_output, logits, targets, sample)
+
+        if return_net_output:
+            return loss, sample_size, logging_output, net_output
+        else:
+            return loss, sample_size, logging_output
+
+    @staticmethod
+    def reduce_metrics(stats: Dict[str, Any], prefix: str = None) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        raise NotImplementedError
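BaseCriterion.forward above is the template method: it runs the model, pulls logits and targets, and defers to compute_loss, get_sample_size, and get_logging_outputs. As a minimal sketch of what a concrete subclass looks like (modeled on the sibling criterions in this diff), the example below registers a hypothetical "toy_mse" criterion; the name and the MSE loss are assumptions for illustration, not part of the package.

# Hypothetical criterion sketch; requires the genhpf package installed.
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F

from genhpf.criterions import BaseCriterion, register_criterion
from genhpf.criterions.criterion import CriterionConfig


@dataclass
class ToyMSEConfig(CriterionConfig):
    pass


@register_criterion("toy_mse", dataclass=ToyMSEConfig)  # "toy_mse" is a made-up name
class ToyMSE(BaseCriterion):
    # forward() is inherited from BaseCriterion; only the three hooks are defined.
    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        loss = F.mse_loss(logits, targets.float(), reduction="sum")
        return loss, [loss.detach().item()]

    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
        return targets.numel()

    def get_logging_outputs(
        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
    ) -> Dict[str, Any]:
        return logging_output

Note that reduce_metrics is left unimplemented in this sketch, matching the base class contract, so aggregated logging would still need to be filled in as the concrete criterions below do.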
genhpf/criterions/cross_entropy.py
@@ -0,0 +1,202 @@
+import math
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from omegaconf import II
+
+import genhpf.utils.utils as utils
+from genhpf.criterions import BaseCriterion, register_criterion
+from genhpf.criterions.criterion import CriterionConfig
+from genhpf.loggings import meters, metrics
+from genhpf.loggings.meters import safe_round
+
+
+@dataclass
+class CrossEntropyConfig(CriterionConfig):
+    report_auc: bool = field(
+        default=False,
+        metadata={
+            "help": "whether to report auc. note that this is only available in eval mode and "
+            "can cause memory and performance issues if enabled."
+        },
+    )
+    ignore_index: int = II("dataset.ignore_index")
+
+
+@register_criterion("cross_entropy", dataclass=CrossEntropyConfig)
+class CrossEntropy(BaseCriterion):
+    def __init__(self, cfg: CrossEntropyConfig):
+        super().__init__(cfg)
+
+        if self.task_names is not None and len(self.task_names) > 1:
+            raise ValueError(
+                "cross_entropy only supports single task training. if you want to train multiple"
+                " tasks, use multi_task_criterion instead."
+            )
+
+        self.report_auc = cfg.report_auc
+        self.ignore_index = cfg.ignore_index
+
+    def compute_loss(
+        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
+    ) -> Tuple[torch.Tensor, List[float]]:
+        """Compute the loss given the logits and targets from the model."""
+        logits = logits.view(-1, logits.size(-1))
+        targets = targets.view(-1).long()
+
+        if torch.all(targets == self.ignore_index):
+            return logits.new_tensor(0.0), [0.0]
+
+        loss = F.cross_entropy(logits, targets, reduction="sum", ignore_index=self.ignore_index)
+
+        return loss, [loss.detach().item()]
+
+    def get_sample_size(self, sample, targets: torch.Tensor) -> int:
+        if "sample_size" in sample:
+            sample_size = sample["sample_size"]
+        else:
+            sample_size = targets.numel()
+        return sample_size
+
+    def get_logging_outputs(
+        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
+    ) -> List[Dict[str, Any]]:
+        with torch.no_grad():
+            logits = logits.view(-1, logits.size(-1))
+            targets = targets.view(-1).long()
+
+            valid_indices = torch.where(targets != self.ignore_index)
+            if len(valid_indices[0]) == 0:
+                return {}
+
+            logits = logits[valid_indices]
+            targets = targets[valid_indices]
+
+            preds = logits.argmax(dim=-1)
+            count = targets.numel()
+            corr = (preds == targets).sum().item()
+
+        logging_output["correct"] = corr
+        logging_output["count"] = count
+
+        # report aucs only in eval mode
+        if self.report_auc and not self.training:
+            probs = torch.sigmoid(logits).view(-1)
+            targets = F.one_hot(targets, logits.size(-1)).float().view(-1)
+
+            logging_output["_y_true"] = targets.cpu().numpy()
+            logging_output["_y_score"] = probs.cpu().numpy()
+
+        return logging_output
+
+    # def forward(self, model, sample):
+    #     net_output = model(**sample['net_input'])
+    #     if isinstance(model, DistributedDataParallel):
+    #         logits = model.module.get_outputs(
+    #             net_output,
+    #             task=self.args.train_task,
+    #             normalize=False
+    #         )
+    #         targets = model.module.get_targets(sample, net_output, self.args.train_task)
+    #     else:
+    #         logits = model.get_outputs(
+    #             net_output,
+    #             task=self.args.train_task,
+    #             normalize=False
+    #         )
+    #         targets = model.get_targets(sample, net_output, self.args.train_task)
+
+    #     loss_dict = {}
+    #     logging_output = {}
+
+    #     if self.args.train_task == 'pretrain' and self.args.pretrain_task in ['mlm', 'spanmlm']:
+    #         B, S= targets['input_label'].shape
+    #         for victim in self.args.mask_list:
+    #             loss = F.cross_entropy(
+    #                 logits[victim+'_ids'].view(B*S, -1),
+    #                 targets[victim+'_label'].view(-1)
+    #             )
+    #             loss_dict[victim+'_loss'] = loss
+
+    #             with torch.no_grad():
+    #                 preds = torch.argmax(logits[victim+'_ids'], dim=-1).view(-1).detach().cpu()
+    #                 target_label = targets[victim+'_label'].view(-1).detach().cpu()
+    #                 mask_idcs = (target_label != -100) & (target_label != 0)
+    #                 total = mask_idcs.sum()
+    #                 correct = (preds[mask_idcs] == target_label[mask_idcs]).sum().float()
+
+    #                 logging_output[victim+'_correct'] = correct
+    #                 logging_output[victim+'_total'] = total
+
+    #         loss = sum(loss_dict.values())
+    #         sample_size = len(sample)
+    #         logging_output['loss'] = loss.item()
+    #         logging_output['sample_size'] = sample_size
+
+    #     elif self.args.train_task in ['finetune', 'scratch']:
+
+    #         sample_size = len(targets)
+    #         loss = F.cross_entropy(
+    #             logits, F.one_hot(
+    #                 targets.long(),
+    #                 self.multi_label_dict[self.args.pred_src][self.args.pred_target]
+    #             ).float().to(logits.device),
+    #             reduction=self.ce_reduction_mode
+    #         )
+
+    #         logging_output['loss'] = loss.item()
+    #         logging_output['sample_size'] = sample_size
+
+    #         with torch.no_grad():
+    #             probs = torch.sigmoid(logits).view(-1).detach()
+    #             targets = self.mlb.transform(np.expand_dims(targets.view(-1).cpu(), axis=1)).flatten()
+
+    #             logging_output["_y_true"] = targets
+    #             logging_output["_y_score"] = probs.cpu().numpy()
+
+    #     return loss, sample_size, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        if prefix is None:
+            prefix = ""
+        elif prefix is not None and not prefix.endswith("_"):
+            prefix = prefix + "_"
+
+        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+
+        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))
+
+        metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)
+
+        if "_y_true" in logging_outputs[0] and "_y_score" in logging_outputs[0]:
+            y_true = np.concatenate([log.get("_y_true", []) for log in logging_outputs])
+            y_score = np.concatenate([log.get("_y_score", []) for log in logging_outputs])
+
+            metrics.log_custom(meters.AUCMeter, f"_{prefix}auc", y_score, y_true)
+
+        correct = sum(log.get("correct", 0) for log in logging_outputs)
+        metrics.log_scalar(f"_{prefix}correct", correct)
+
+        total = sum(log.get("count", 0) for log in logging_outputs)
+        metrics.log_scalar(f"_{prefix}total", total)
+
+        if total > 0:
+            metrics.log_derived(
+                f"{prefix}accuracy",
+                lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
+                if meters[f"_{prefix}total"].sum > 0
+                else float("nan"),
+            )
+
+    def post_validate(self, stats, agg, **kwargs):
+        for key in agg.keys():
+            if key.startswith("_") and key.endswith("auc"):
+                stats[key[1:-3] + "auroc"] = agg[key].auroc
+                stats[key[1:-3] + "auprc"] = agg[key].auprc
+
+        return stats
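The behavior worth calling out in this criterion is the ignore_index path: logits and targets are flattened, a batch made entirely of ignored positions contributes a zero loss, and otherwise the sum-reduced cross entropy skips ignored positions. A self-contained sketch of that masking follows; ignore_index = -100 is an assumption made for the example, since the package interpolates the real value from dataset.ignore_index.

# Illustrative only: the ignore_index handling of compute_loss above, on toy data.
import torch
import torch.nn.functional as F

ignore_index = -100                              # assumed value for this sketch
logits = torch.randn(4, 3)                       # 4 positions, 3 classes
targets = torch.tensor([0, 2, ignore_index, 1])  # one position is ignored

if torch.all(targets == ignore_index):
    loss = logits.new_tensor(0.0)                # nothing to learn from this batch
else:
    # sum-reduced cross entropy; ignored positions contribute nothing to the sum
    loss = F.cross_entropy(logits, targets, reduction="sum", ignore_index=ignore_index)

valid = targets != ignore_index
accuracy = (logits[valid].argmax(dim=-1) == targets[valid]).float().mean()
print(loss.item(), accuracy.item())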
genhpf/criterions/multi_task_criterion.py
@@ -0,0 +1,177 @@
+import math
+import re
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import genhpf.utils.utils as utils
+from genhpf.criterions import BaseCriterion, register_criterion
+from genhpf.criterions.criterion import CriterionConfig
+from genhpf.loggings import metrics
+from genhpf.models.genhpf import GenHPF
+
+from . import build_criterion
+
+
+@dataclass
+class MultiTaskCriterionConfig(CriterionConfig):
+    task_loss_weights: Optional[List[float]] = field(
+        default=None,
+        metadata={
+            "help": "weights for each loss term. if given, has to be a float list of size " "n_criterions"
+        },
+    )
+    args: Any = field(
+        default=None,
+        metadata={
+            "help": "configurations for each criterion where the name of each argument should "
+            "match with the corresponding task name."
+        },
+    )
+
+
+@register_criterion("multi_task_criterion", dataclass=MultiTaskCriterionConfig)
+class MultiTaskCriterion(BaseCriterion):
+    def __init__(self, cfg: MultiTaskCriterionConfig):
+        super().__init__(cfg)
+
+        criterions = {}
+        for task_name in self.task_names:
+            criterion_cfg = getattr(cfg.args, task_name)
+            criterions[task_name] = build_criterion(criterion_cfg)
+        self.criterions = criterions
+
+        if cfg.task_loss_weights is None:
+            self.task_loss_weights = [1.0] * len(criterions)
+        else:
+            self.task_loss_weights = cfg.task_loss_weights
+
+    def forward(self, model: GenHPF, sample, return_net_output=False):
+        net_output = model(**sample["net_input"])
+        logits = model.get_logits(sample, net_output)
+        targets = model.get_targets(sample, net_output)
+
+        if not isinstance(logits, dict):
+            logits = {self.task_names[0]: logits}
+        if not isinstance(targets, dict):
+            targets = {self.task_names[0]: targets}
+
+        if len(logits) != len(self.task_names) or len(targets) != len(self.task_names):
+            raise ValueError(
+                "number of logits and targets should be equal to the number of tasks. "
+                f"got {len(logits)} logits and {len(targets)} targets for "
+                f"{len(self.task_names)} tasks"
+            )
+
+        loss = 0.0
+        logging_outputs = dict()
+        for i, task_name in enumerate(self.task_names):
+            criterion = self.criterions[task_name]
+            assert (
+                task_name in logits and task_name in targets
+            ), f"task name {task_name} not found in logits or targets"
+            task_logits = logits[task_name]
+            task_targets = targets[task_name]
+            task_loss, task_losses_to_log = criterion.compute_loss(
+                logits=task_logits, targets=task_targets, sample=sample, net_output=net_output, model=model
+            )
+            task_loss *= self.task_loss_weights[i]
+            sample_size = criterion.get_sample_size(sample, task_targets)
+
+            logging_outputs[f"<{task_name}>_criterion_cls"] = criterion.__class__
+            if len(task_losses_to_log) > 1:
+                logging_outputs[f"{task_name}_loss"] = task_loss.item()
+                for j, l in enumerate(task_losses_to_log):
+                    logging_outputs[f"<{task_name}>_loss_{j}"] = l
+            else:
+                logging_outputs[f"<{task_name}>_loss"] = task_losses_to_log[0]
+            logging_outputs[f"<{task_name}>_sample_size"] = sample_size
+
+            task_logging_output = criterion.get_logging_outputs({}, task_logits, task_targets, sample)
+            for log, value in task_logging_output.items():
+                if log.startswith("_"):
+                    log = log[1:]
+                    logging_outputs[f"_<{task_name}>_{log}"] = value
+                else:
+                    logging_outputs[f"<{task_name}>_{log}"] = value
+
+            # divide task loss by the sample size beforehand to handle different sample
+            # sizes for multiple criterions
+            loss += task_loss / logging_outputs[f"<{task_name}>_sample_size"]
+
+        # manipulate sample_size to be 1 to avoid double-dividing gradients in optimizer later
+        sample_size = 1
+
+        if return_net_output:
+            return loss, sample_size, logging_outputs, net_output
+        else:
+            return loss, sample_size, logging_outputs
+
+    @staticmethod
+    def reduce_metrics(logging_outputs: List[Dict[str, Any]]) -> None:
+        log_keys = logging_outputs[0].keys()
+
+        grouped_log_keys = defaultdict(list)
+        for lk in log_keys:
+            group = re.search(r"\<.*\>", lk)
+            offset = group.end() + 1
+            group = group.group()[1:-1]
+            key = lk[offset:]
+            if lk.startswith("_"):
+                key = "_" + key
+            grouped_log_keys[group].append(key)
+
+        total_loss = 0
+        for group, log_keys in grouped_log_keys.items():
+            criterion_cls = logging_outputs[0][f"<{group}>_criterion_cls"]
+            logging_output = []
+            for log in logging_outputs:
+                log_dict = {}
+                for log_key in set(log_keys) - {"criterion_cls"}:
+                    if log_key.startswith("_") and f"_<{group}>{log_key}" in log:
+                        log_dict[log_key] = log[f"_<{group}>{log_key}"]
+                    elif f"<{group}>_{log_key}" in log:
+                        log_dict[log_key] = log[f"<{group}>_{log_key}"]
+                logging_output.append(log_dict)
+            criterion_cls.reduce_metrics(logging_output, prefix=group)
+
+            loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_output))
+            sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_output))
+
+            total_loss += loss_sum / (sample_size or 1) / math.log(2)
+
+        metrics.log_scalar("loss", total_loss, 1, round=3)
+
+    def post_validate(self, stats, agg, **kwargs):
+        task_agg = {}
+        for key in agg:
+            for task_name in self.task_names:
+                if key.startswith(task_name) or key[1:].startswith(task_name):
+                    if task_name not in task_agg:
+                        task_agg[task_name] = {}
+                    task_agg[task_name][key] = agg[key]
+                    break
+
+        for task_name, task_agg in task_agg.items():
+            if hasattr(self.criterions[task_name], "post_validate"):
+                stats = self.criterions[task_name].post_validate(stats, task_agg, **kwargs)
+
+        for key in list(stats.keys()):
+            for task_name in self.task_names:
+                if key.startswith(task_name):
+                    stat_key = key[len(task_name) + 1 :]
+                    if f"avg_{stat_key}" not in stats:
+                        stats[f"avg_{stat_key}"] = []
+                    stats[f"avg_{stat_key}"].append(stats[key])
+                    break
+
+        for key in list(stats.keys()):
+            if key.startswith("avg_"):
+                stats[key] = sum(stats[key]) / len(stats[key])
+        return stats
+
+    def eval(self):
+        super().eval()
+        for criterion in self.criterions.values():
+            criterion.eval()
+        return self
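The detail that distinguishes this forward from the base class is its normalization: each task's (already weighted) loss is divided by that task's own sample size before the tasks are summed, and the returned sample_size is pinned to 1 so the trainer does not divide the combined loss again. A toy numeric sketch of that bookkeeping follows, with made-up task names and values.

# Illustrative arithmetic for the per-task normalization in MultiTaskCriterion.forward.
task_losses = {"mortality": 12.0, "los_7day": 30.0}    # sum-reduced losses (made up)
sample_sizes = {"mortality": 8, "los_7day": 24}         # labels contributing to each task
task_loss_weights = {"mortality": 1.0, "los_7day": 1.0}

# each task is averaged over its own sample size before tasks are combined,
# so a task with more labels does not dominate just because its summed loss is larger
loss = sum(
    task_loss_weights[t] * task_losses[t] / sample_sizes[t] for t in task_losses
)

# the criterion then reports sample_size = 1, so the optimizer/trainer
# does not divide the combined loss a second time
sample_size = 1
print(loss, sample_size)  # 1.5 + 1.25 = 2.75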
genhpf/criterions/simclr_criterion.py
@@ -0,0 +1,84 @@
+import math
+from dataclasses import dataclass, field
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+
+import genhpf.utils.utils as utils
+from genhpf.criterions import BaseCriterion, register_criterion
+from genhpf.criterions.criterion import CriterionConfig
+from genhpf.loggings import metrics
+
+
+@dataclass
+class SimCLRCriterionConfig(CriterionConfig):
+    temp: float = field(default=0.1, metadata={"help": "temperature to divide logits by"})
+
+
+@register_criterion("simclr_criterion", dataclass=SimCLRCriterionConfig)
+class SimCLRCriterion(BaseCriterion):
+    def __init__(self, cfg: SimCLRCriterionConfig):
+        super().__init__(cfg)
+
+        self.temp = cfg.temp
+
+    def compute_loss(
+        self, logits: torch.Tensor, targets: torch.Tensor = None, sample=None, net_output=None, model=None
+    ) -> Tuple[torch.Tensor, List[float]]:
+        """Compute the loss given the logits and targets from the model."""
+        logits = F.normalize(logits, dim=1)  # normalize logits
+
+        bsz = int(logits.shape[0] / 2)
+
+        mask = 1 - torch.eye(bsz * 2, dtype=torch.uint8).to(logits.device)
+        pos_ind = (
+            torch.arange(bsz * 2).to(logits.device),
+            2
+            * torch.arange(bsz, dtype=torch.long)
+            .unsqueeze(1)
+            .repeat(1, 2)
+            .view(-1, 1)
+            .squeeze()
+            .to(logits.device),
+        )
+        neg_mask = torch.ones((bsz * 2, bsz * 2 - 1), dtype=torch.uint8).to(logits.device)
+        neg_mask[pos_ind] = 0
+
+        # Cosine similarity computation
+        sim_matrix = torch.matmul(logits, logits.T)  # cosine similarity computation
+
+        # Eliminate similarity between same view
+        sim_matrix = torch.masked_select(sim_matrix, mask.bool()).view(sim_matrix.size(0), -1)
+
+        positives = sim_matrix[pos_ind].unsqueeze(1)
+        negatives = torch.masked_select(sim_matrix, neg_mask.bool()).view(sim_matrix.size(0), -1)
+
+        logits = torch.cat((positives, negatives), dim=1)
+        logits /= self.temp
+
+        target = torch.zeros((logits.size(0),), dtype=torch.long).to(logits.device)
+
+        loss = F.cross_entropy(logits, target, reduction="sum")
+
+        return loss, [loss.detach().item()]
+
+    def get_sample_size(self, sample, targets: torch.Tensor = None) -> int:
+        return sample["net_input"]["input_ids"].size(0)
+
+    def get_logging_outputs(self, logging_output, logits, target, sample=None, net_output=None):
+        return logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, prefix: str = None) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        if prefix is None:
+            prefix = ""
+        elif prefix is not None and not prefix.endswith("_"):
+            prefix = prefix + "_"
+
+        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+
+        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))
+
+        metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)
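Reading the pos_ind construction in compute_loss, this criterion appears to assume the two augmented views of each example are interleaved along the batch axis ([x0_v1, x0_v2, x1_v1, x1_v2, ...]) and computes a sum-reduced NT-Xent loss with the positive pair placed at index 0 of each row. The sketch below computes the same quantity in a more direct way, as a cross-check; the embeddings, batch size, and temperature are arbitrary, and the interleaving is an assumption inferred from the code rather than stated by the package.

# Cross-check sketch for compute_loss above (not part of the package).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, dim, temp = 3, 8, 0.1
z = F.normalize(torch.randn(bsz * 2, dim), dim=1)   # 2N normalized embeddings

sim = z @ z.T / temp                                 # cosine similarities / temperature
sim.fill_diagonal_(float("-inf"))                    # a view is never its own candidate

# row 2k's positive is row 2k+1 and vice versa (the interleaving assumption)
pos_idx = torch.arange(bsz * 2) ^ 1                  # 0<->1, 2<->3, 4<->5, ...
loss = F.cross_entropy(sim, pos_idx, reduction="sum")
print(loss.item())

Masking the diagonal with -inf is equivalent to the masked_select formulation above, since exp(-inf) contributes nothing to the softmax denominator, so both versions should yield the same loss for the same embeddings.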