genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/models/genhpf.py
ADDED
@@ -0,0 +1,244 @@
import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
from omegaconf import II

from genhpf.configs import BaseConfig, ChoiceEnum
from genhpf.modules import PositionalEncoding

logger = logging.getLogger(__name__)

GENHPF_MODEL_ARCH_CHOICES = ChoiceEnum(["hierarchical", "flattened"])
GENHPF_AGGREGATOR_ARCH_CHOICES = ChoiceEnum(["transformer", "performer"])
GENHPF_EMBEDDING_METHOD_CHOICES = ChoiceEnum(["code", "text"])


@dataclass
class GenHPFConfig(BaseConfig):
    structure: GENHPF_MODEL_ARCH_CHOICES = field(
        default="hierarchical",
        metadata={"help": "Architecture choice for GenHPF. Choose from hierarchical or flattened"},
    )
    embedding_method: GENHPF_EMBEDDING_METHOD_CHOICES = field(
        default="text", metadata={"help": "Embedding method choice for GenHPF. Choose from code or text"}
    )

    encoder_max_seq_len: int = field(
        default=128,
        metadata={
            "help": "max sequence length for the event encoder, only used when structure is "
            "hierarchical. this is the max number of tokens in an event."
        },
    )
    agg_max_seq_len: int = field(
        default=256,
        metadata={
            "help": "max sequence length for the event aggregator. In the hierarchical structure, "
            "this is the max number of events in a sample. In the flattened structure, this is "
            "the max sequence length of the flattened input."
        },
    )

    # configs for event encoder in hierarchical structure
    encoder_layers: int = field(default=2, metadata={"help": "num encoder layers in the transformer"})
    encoder_embed_dim: int = field(default=128, metadata={"help": "encoder embedding dimension"})
    encoder_ffn_embed_dim: int = field(default=512, metadata={"help": "encoder embedding dimension for FFN"})
    encoder_attention_heads: int = field(
        default=4, metadata={"help": "num attention heads in the transformer"}
    )

    # configs for event aggregator
    agg_arch: GENHPF_AGGREGATOR_ARCH_CHOICES = field(
        default="transformer",
        metadata={
            "help": "Architecture choice for the event aggregator. Choose from transformer or performer"
        },
    )
    agg_layers: int = field(default=2, metadata={"help": "num layers in the transformer"})
    agg_embed_dim: int = field(default=128, metadata={"help": "hidden dimension for the event aggregator"})
    agg_ffn_embed_dim: int = field(
        default=512, metadata={"help": "hidden dimension for the FFN in the event aggregator"}
    )
    agg_attention_heads: int = field(
        default=4, metadata={"help": "num attention heads for the event aggregator"}
    )

    dropout: float = field(default=0.2, metadata={"help": "dropout probability"})

    from_pretrained: Optional[str] = field(
        default=None, metadata={"help": "path to the pretrained model if available"}
    )

    vocab_size: int = II("dataset.vocab_size")
    debug: bool = II("common.debug")


class GenHPF(nn.Module):
    def __init__(self, cfg: GenHPFConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.debug:
            cfg.encoder_layers = 1
            cfg.encoder_embed_dim = 32
            cfg.encoder_ffn_embed_dim = 128
            cfg.encoder_attention_heads = 2
            cfg.agg_layers = 1
            cfg.agg_embed_dim = 32
            cfg.agg_ffn_embed_dim = 128
            cfg.agg_attention_heads = 2

        self.structure = cfg.structure
        assert self.structure in GENHPF_MODEL_ARCH_CHOICES

        self.embedding_method = cfg.embedding_method
        assert self.embedding_method in GENHPF_EMBEDDING_METHOD_CHOICES

        self.word_embeddings = nn.Embedding(cfg.vocab_size, cfg.encoder_embed_dim, padding_idx=0)
        if self.embedding_method == "text":
            # we currently use 7 token types and 16 digit places in this version
            self.token_type_vocab_size = 7
            self.digit_place_vocab_size = 16
            self.token_type_embeddings = nn.Embedding(
                self.token_type_vocab_size, cfg.encoder_embed_dim, padding_idx=0
            )
            self.digit_place_embeddings = nn.Embedding(
                self.digit_place_vocab_size, cfg.encoder_embed_dim, padding_idx=0
            )

        max_length = cfg.encoder_max_seq_len if self.structure == "hierarchical" else cfg.agg_max_seq_len
        self.positional_encodings = PositionalEncoding(cfg.encoder_embed_dim, cfg.dropout, max_length)

        self.event_encoder = None
        if self.structure == "hierarchical" and self.embedding_method == "text":
            self.encoder_layer_norm = nn.LayerNorm(cfg.encoder_embed_dim, eps=1e-12)
            encoder_layer = nn.TransformerEncoderLayer(
                cfg.encoder_embed_dim,
                cfg.encoder_attention_heads,
                cfg.encoder_ffn_embed_dim,
                cfg.dropout,
                batch_first=True,
            )
            self.event_encoder = nn.TransformerEncoder(encoder_layer, self.cfg.encoder_layers)
            self.post_encode_proj = nn.Linear(cfg.encoder_embed_dim, cfg.agg_embed_dim)
            self.event_positional_encodings = PositionalEncoding(
                cfg.agg_embed_dim, cfg.dropout, cfg.agg_max_seq_len
            )

        self.agg_layer_norm = nn.LayerNorm(cfg.agg_embed_dim, eps=1e-12)
        if cfg.agg_arch == "transformer":
            agg_layer = nn.TransformerEncoderLayer(
                cfg.agg_embed_dim,
                cfg.agg_attention_heads,
                cfg.agg_ffn_embed_dim,
                cfg.dropout,
                batch_first=True,
            )
            self.event_aggregator = nn.TransformerEncoder(agg_layer, cfg.agg_layers)
        elif cfg.agg_arch == "performer":
            from performer_pytorch import Performer

            self.event_aggregator = Performer(
                dim=cfg.agg_embed_dim,
                depth=cfg.agg_layers,
                heads=cfg.agg_attention_heads,
                dim_head=64,
                nb_features=64,
                reversible=True,
                generalized_attention=True,
                ff_dropout=cfg.dropout,
                attn_dropout=cfg.dropout,
                shift_tokens=True,
            )
        else:
            raise NotImplementedError(f"Unsupported event aggregator architecture: {cfg.agg_arch}")

    @classmethod
    def build_model(cls, cfg):
        """Build a new model instance."""
        if cfg.from_pretrained:
            return cls.from_pretrained(cfg, cfg.from_pretrained)
        else:
            return cls(cfg)

    @classmethod
    def from_pretrained(cls, cfg, checkpoint_path):
        model = cls(cfg)
        state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        assert len(unexpected_keys) == 0, f"unexpected keys: {unexpected_keys}"

        finetuning_parameter_names = []
        if hasattr(model, "get_finetuning_parameter_names"):
            finetuning_parameter_names = model.get_finetuning_parameter_names()
        missing_keys = list(set(missing_keys) - set(finetuning_parameter_names))
        assert len(missing_keys) == 0, f"missing keys: {missing_keys}"

        logger.info(f"loaded pre-trained model from {checkpoint_path}")

        return model

    def get_logits(self, sample, net_output):
        """Get logits from the model output."""
        raise NotImplementedError

    def get_targets(self, sample, net_output):
        """Get targets from the sample or model output."""
        raise NotImplementedError

    def forward(
        self,
        input_ids: torch.Tensor,
        type_ids: torch.Tensor = None,
        dpe_ids: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
        encoder_only: bool = False,
        **kwargs,
    ):
        x = self.word_embeddings(input_ids)
        if self.embedding_method == "text":
            if type_ids is not None:
                x += self.token_type_embeddings(type_ids)
            if dpe_ids is not None:
                x += self.digit_place_embeddings(dpe_ids)

        if self.structure == "hierarchical":
            assert input_ids.ndim == 3  # (batch, num_events, num_words)
            batch_size, num_events = input_ids.shape[0], input_ids.shape[1]

            x = x.view(batch_size * num_events, -1, self.cfg.encoder_embed_dim)
            x = self.positional_encodings(x)
            x = self.encoder_layer_norm(x)
            if padding_mask is None:
                padding_mask = input_ids.view(batch_size * num_events, -1).eq(0).to(x.device)
            # x: (batch * num_events, num_words, embed_dim)
            x = self.event_encoder(x, src_key_padding_mask=padding_mask)

            if padding_mask.any():
                x[padding_mask] = 0
            x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))
            x = self.post_encode_proj(x).view(batch_size, num_events, -1)  # (batch, num_events, embed_dim)

            padding_mask = input_ids[:, :, 1].eq(0).to(x.device)

            x = self.event_positional_encodings(x)
        else:
            assert input_ids.ndim == 2  # (batch, seq_len)
            x = self.positional_encodings(x)  # (batch, seq_len, embed_dim)
            if padding_mask is None:
                padding_mask = input_ids.eq(0).to(x.device)

        if encoder_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
            }

        x = self.event_aggregator(x, src_key_padding_mask=padding_mask)

        return x, padding_mask
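Worth noting: after the event encoder, the hierarchical branch collapses each event's token sequence into a single vector by zeroing padded positions and averaging each embedding dimension over its non-zero entries. A minimal standalone sketch of that pooling step, with hypothetical shapes (the real forward also applies post_encode_proj before reshaping, and the `(x != 0)` count assumes non-padded activations are rarely exactly zero):

import torch

batch_size, num_events, num_words, embed_dim = 2, 4, 8, 16  # hypothetical sizes
x = torch.randn(batch_size * num_events, num_words, embed_dim)  # event-encoder output
padding_mask = torch.zeros(batch_size * num_events, num_words, dtype=torch.bool)
padding_mask[:, 5:] = True  # pretend the last three word slots are padding

x[padding_mask] = 0  # padded positions drop out of the sum below
x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))  # per-dimension mean over non-zero entries
x = x.view(batch_size, num_events, -1)  # one vector per event: (batch, num_events, embed_dim)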
genhpf/models/genhpf_mlm.py
ADDED
@@ -0,0 +1,64 @@
import logging
from dataclasses import dataclass

import torch
import torch.nn as nn
from omegaconf import II

from genhpf.models import register_model
from genhpf.models.genhpf import GenHPF, GenHPFConfig

logger = logging.getLogger(__name__)


@dataclass
class GenHPFMLMConfig(GenHPFConfig):
    ignore_index: int = II("dataset.ignore_index")


@register_model("genhpf_mlm", dataclass=GenHPFMLMConfig)
class GenHPFMLM(GenHPF):
    def __init__(self, cfg: GenHPFMLMConfig):
        super().__init__(cfg)

        self.ignore_index = cfg.ignore_index

        self.input_ids_proj = nn.Linear(cfg.agg_embed_dim, cfg.vocab_size)

    @classmethod
    def build_model(cls, cfg):
        """Build a new model instance."""
        return cls(cfg)

    def get_logits(self, sample, net_output):
        masked_indices = torch.where(
            (sample["input_label"] > 0) & (sample["input_label"] != self.ignore_index)
        )
        return net_output["input_ids"][masked_indices]

    def get_targets(self, sample, net_output):
        masked_indices = torch.where(
            (sample["input_label"] > 0) & (sample["input_label"] != self.ignore_index)
        )
        return sample["input_label"][masked_indices]

    def forward(
        self,
        input_ids: torch.Tensor,
        type_ids: torch.Tensor = None,
        dpe_ids: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
        **kwargs,
    ):
        x, padding_mask = super().forward(
            input_ids=input_ids, type_ids=type_ids, dpe_ids=dpe_ids, padding_mask=padding_mask, **kwargs
        )

        input_ids = self.input_ids_proj(x)

        return {"input_ids": input_ids, "padding_mask": padding_mask}

    def get_pretraining_parameter_names(self):
        ret = []
        ret.extend(["input_ids_proj" + "." + x[0] for x in self.input_ids_proj.named_parameters()])
        return ret
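The torch.where calls in get_logits and get_targets restrict scoring to positions that were actually masked: input_label appears to be 0 at unmasked positions (hence the `> 0` filter) and to hold the original token id at masked ones. A small illustration with made-up tensors (the -100 ignore_index here is an assumption; the real value is interpolated from dataset.ignore_index):

import torch

ignore_index = -100  # hypothetical; resolved from dataset.ignore_index in practice
input_label = torch.tensor([[0, 5, -100, 9], [3, 0, 0, -100]])  # 0 = not masked
logits = torch.randn(2, 4, 30)  # (batch, seq_len, vocab_size)

masked_indices = torch.where((input_label > 0) & (input_label != ignore_index))
print(logits[masked_indices].shape)  # torch.Size([3, 30]): one row per masked token
print(input_label[masked_indices])   # tensor([5, 9, 3]): the token ids to predict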
genhpf/models/genhpf_predictor.py
ADDED
@@ -0,0 +1,73 @@
import logging
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
from omegaconf import II

from genhpf.models import register_model
from genhpf.models.genhpf import GenHPF, GenHPFConfig

logger = logging.getLogger(__name__)


@dataclass
class GenHPFPredictorConfig(GenHPFConfig):
    tasks: List[str] = II("criterion.task_names")
    num_labels: List[int] = II("criterion.num_labels")


@register_model("genhpf_predictor", dataclass=GenHPFPredictorConfig)
class GenHPFPredictor(GenHPF):
    def __init__(self, cfg: GenHPFPredictorConfig):
        super().__init__(cfg)

        self.num_labels = cfg.num_labels
        self.tasks = cfg.tasks
        assert len(self.num_labels) == len(
            cfg.tasks
        ), "The number of num_labels must be equal to the number of tasks"

        self.final_proj = nn.ModuleDict()
        for i, task in enumerate(cfg.tasks):
            self.final_proj[task] = nn.Linear(cfg.agg_embed_dim, self.num_labels[i])

    def get_logits(self, sample, net_output):
        if len(self.tasks) == 1:
            return net_output[self.tasks[0]]
        else:
            return net_output

    def get_targets(self, sample, net_output):
        if len(self.tasks) == 1:
            return sample["label"][self.tasks[0]]
        else:
            return sample["label"]

    def forward(
        self,
        input_ids: torch.Tensor,
        type_ids: torch.Tensor = None,
        dpe_ids: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
        **kwargs,
    ):
        x, padding_mask = super().forward(
            input_ids=input_ids, type_ids=type_ids, dpe_ids=dpe_ids, padding_mask=padding_mask, **kwargs
        )

        if padding_mask is not None and padding_mask.any():
            x[padding_mask] = 0
        x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))

        ret = {}
        for task, proj in self.final_proj.items():
            ret[task] = proj(x)

        return ret

    def get_finetuning_parameter_names(self):
        ret = []
        ret.extend(["final_proj" + "." + x[0] for x in self.final_proj.named_parameters()])
        return ret
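final_proj is an nn.ModuleDict with one linear head per task, so a single pooled representation yields a dict of per-task logits. A quick sketch of the same pattern, with hypothetical task names and label counts:

import torch
import torch.nn as nn

tasks, num_labels, agg_embed_dim = ["mortality", "los_3day"], [2, 3], 128  # hypothetical
final_proj = nn.ModuleDict(
    {task: nn.Linear(agg_embed_dim, n) for task, n in zip(tasks, num_labels)}
)
pooled = torch.randn(4, agg_embed_dim)  # (batch, agg_embed_dim) after mean pooling
ret = {task: proj(pooled) for task, proj in final_proj.items()}
print({k: v.shape for k, v in ret.items()})  # mortality -> (4, 2), los_3day -> (4, 3)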
genhpf/models/genhpf_simclr.py
ADDED
@@ -0,0 +1,58 @@
import logging
from dataclasses import dataclass, field

import torch

from genhpf.models import register_model
from genhpf.models.genhpf import GenHPF, GenHPFConfig
from genhpf.modules import GatherLayer
from genhpf.utils import distributed_utils as dist_utils

logger = logging.getLogger(__name__)


@dataclass
class GenHPFSimCLRConfig(GenHPFConfig):
    all_gather: bool = field(
        default=True, metadata={"help": "whether or not to apply all gather across different gpus"}
    )


@register_model("genhpf_simclr", dataclass=GenHPFSimCLRConfig)
class GenHPFSimCLR(GenHPF):
    def __init__(self, cfg: GenHPFSimCLRConfig):
        super().__init__(cfg)

        self.all_gather = cfg.all_gather

    @classmethod
    def build_model(cls, cfg):
        """Build a new model instance."""
        return cls(cfg)

    def get_logits(self, sample, net_output):
        return net_output

    def get_targets(self, sample, net_output):
        return None

    def forward(
        self,
        input_ids: torch.Tensor,
        type_ids: torch.Tensor = None,
        dpe_ids: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
        **kwargs,
    ):
        x, padding_mask = super().forward(
            input_ids=input_ids, type_ids=type_ids, dpe_ids=dpe_ids, padding_mask=padding_mask, **kwargs
        )

        if padding_mask is not None and padding_mask.any():
            x[padding_mask] = 0
        x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))

        if self.all_gather and dist_utils.get_data_parallel_world_size() > 1:
            x = torch.cat(GatherLayer.apply(x), dim=0)

        return x
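GatherLayer.apply(x) has to be autograd-aware: a plain dist.all_gather returns tensors with no gradient path back to the local shard, which would break a contrastive loss computed over the gathered batch. The body of genhpf/modules/gather_layer.py is not shown in this section, but the widely used SimCLR-style pattern looks roughly like this (a sketch under that assumption, not necessarily the package's exact implementation):

import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all ranks while keeping gradients flowing back
    to each rank's local input."""

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)  # all_gather alone does not propagate gradients
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)  # sum each slice's gradient across ranks
        return all_gradients[dist.get_rank()]  # keep the slice for the local input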