genhpf 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0

genhpf/models/genhpf_wav2vec2.py
@@ -0,0 +1,304 @@
+import logging
+from dataclasses import dataclass, field
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from genhpf.models import register_model
+from genhpf.models.genhpf import GenHPF, GenHPFConfig
+from genhpf.modules import GradMultiply, GumbelVectorQuantizer, LayerNorm
+from genhpf.utils import utils
+from genhpf.utils.data_utils import compute_mask_indices
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GenHPFWav2Vec2Config(GenHPFConfig):
+    logit_temp: float = field(default=0.1, metadata={"help": "temperature to divide logits by"})
+    latent_vars: int = field(
+        default=320, metadata={"help": "number of latent variables V in each group of the codebook"}
+    )
+    latent_groups: int = field(
+        default=2, metadata={"help": "number of groups G of latent variables in the codebook"}
+    )
+    latent_temp: Tuple[float, float, float] = field(
+        default=(2, 0.5, 0.999995),
+        metadata={
+            "help": "temperature for latent variable sampling. "
+            "can be tuple of 3 values (start, end, decay)"
+        },
+    )
+    final_dim: int = field(
+        default=128, metadata={"help": "project final representations and targets to this many dimensions."}
+    )
+
+    num_negatives: int = field(
+        default=25, metadata={"help": "number of negative examples from the same sample"}
+    )
+    codebook_negatives: int = field(default=0, metadata={"help": "number of negative examples codebook"})
+
+    # mask
+    mask_prob: float = field(default=0.65, metadata={"help": "probability of replacing a token with mask"})
+    mask_length: int = field(default=1, metadata={"help": "mask length"})
+    no_mask_overlap: bool = field(default=False, metadata={"help": "whether to allow masks to overlap"})
+    mask_min_space: int = field(
+        default=0, metadata={"help": "min space between spans (if no overlap is enabled)"}
+    )
+
+    feature_grad_mult: float = field(
+        default=0.1, metadata={"help": "multiply event encoder gradients by this"}
+    )
+    dropout_input: float = field(default=0.1, metadata={"help": "dropout to apply to the input"})
+    dropout_features: float = field(default=0.1, metadata={"help": "dropout to apply to the features"})
+
+
+@register_model("genhpf_wav2vec2", dataclass=GenHPFWav2Vec2Config)
+class GenHPFWav2Vec2(GenHPF):
+    def __init__(self, cfg: GenHPFWav2Vec2Config):
+        super().__init__(cfg)
+
+        self.logit_temp = cfg.logit_temp
+
+        self.mask_prob = cfg.mask_prob
+        self.mask_length = cfg.mask_length
+        self.no_mask_overlap = cfg.no_mask_overlap
+        self.mask_min_space = cfg.mask_min_space
+
+        self.feature_grad_mult = cfg.feature_grad_mult
+
+        self.codebook_negatives = cfg.codebook_negatives
+
+        self.quantizer = GumbelVectorQuantizer(
+            dim=cfg.agg_embed_dim,
+            num_vars=cfg.latent_vars,
+            temp=cfg.latent_temp,
+            groups=cfg.latent_groups,
+            combine_groups=False,
+            vq_dim=cfg.agg_embed_dim,
+            time_first=True,
+        )
+
+        self.layer_norm = LayerNorm(cfg.agg_embed_dim)
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+        self.dropout_features = nn.Dropout(cfg.dropout_features)
+        self.project_q = nn.Linear(cfg.agg_embed_dim, cfg.final_dim)
+        self.final_proj = nn.Linear(cfg.agg_embed_dim, cfg.final_dim)
+        self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.agg_embed_dim).uniform_())
+        self.cross_sample_negatives = 0
+        self.n_negatives = cfg.num_negatives
+
+    @classmethod
+    def build_model(cls, cfg):
+        """Build a new model instance."""
+        return cls(cfg)
+
+    def compute_preds(self, x, y, negatives):
+        neg_is_pos = (y == negatives).all(-1)
+        y = y.unsqueeze(0)
+        targets = torch.cat([y, negatives], dim=0)
+
+        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
+
+        logits = logits / self.logit_temp
+        logits = logits.type_as(x)
+
+        if neg_is_pos.any():
+            logits[1:][neg_is_pos] = float("-inf")
+
+        return logits
+
+    def apply_mask(
+        self,
+        x: torch.Tensor,
+        padding_mask: torch.Tensor,
+        mask_indices=None,
+    ):
+        bsz, tsz, csz = x.shape
+
+        if self.mask_prob > 0:
+            if mask_indices is None:
+                mask_indices = compute_mask_indices(
+                    shape=(bsz, tsz),
+                    padding_mask=padding_mask,
+                    mask_prob=self.mask_prob,
+                    mask_length=self.mask_length,
+                    mask_type="static",
+                    min_masks=2,
+                    no_overlap=self.no_mask_overlap,
+                    min_space=self.mask_min_space,
+                )
+                mask_indices = torch.from_numpy(mask_indices).to(x.device)
+            x[mask_indices] = self.mask_emb
+        else:
+            mask_indices = None
+
+        return x, mask_indices
+
+    def sample_negatives(self, y, num):
+        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
+            return y.new(0)
+
+        batch_size, time_size, feature_size = y.shape
+        y = y.view(-1, feature_size) # B x T x C -> (B x T) x C
+
+        cross_high = time_size * batch_size
+        high = time_size
+        with torch.no_grad():
+            assert high > 1, f"{batch_size, time_size, feature_size}"
+
+            if self.n_negatives > 0:
+                time_sizes = utils.buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten()
+                neg_idxs = torch.randint(low=0, high=high - 1, size=(batch_size, self.n_negatives * num))
+                neg_idxs[neg_idxs >= time_sizes] += 1
+
+            if self.cross_sample_negatives > 0:
+                time_sizes = (
+                    utils.buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten()
+                )
+
+                cross_neg_idxs = torch.randint(
+                    low=0, high=cross_high - 1, size=(batch_size, self.cross_sample_negatives * num)
+                )
+                cross_neg_idxs[cross_neg_idxs >= time_sizes] += 1
+
+        if self.n_negatives > 0:
+            for i in range(1, batch_size):
+                neg_idxs[i] += i * high
+        else:
+            neg_idxs = cross_neg_idxs
+
+        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
+            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)
+
+        negs = y[neg_idxs.view(-1)]
+        negs = negs.view(
+            batch_size, num, self.n_negatives + self.cross_sample_negatives, feature_size
+        ).permute(
+            2, 0, 1, 3
+        ) # to N x B x T x C
+
+        return negs, neg_idxs
+
+    def get_logits(self, sample, net_output):
+        logits = net_output["x"]
+        logits = logits.transpose(0, 2)
+        logits = logits.reshape(-1, logits.size(-1))
+
+        return logits
+
+    def get_targets(self, sample, net_output):
+        x = net_output["x"]
+        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)
+
+    def get_extra_losses(self, net_output):
+        pen = []
+
+        if "prob_perplexity" in net_output:
+            pen.append((net_output["num_vars"] - net_output["prob_perplexity"]) / net_output["num_vars"])
+
+        if "features_pen" in net_output:
+            pen.append(net_output["features_pen"])
+
+        return pen
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        type_ids: torch.Tensor = None,
+        dpe_ids: torch.Tensor = None,
+        padding_mask: torch.Tensor = None,
+        **kwargs,
+    ):
+        if self.feature_grad_mult > 0:
+            features_ret = super().forward(
+                input_ids=input_ids,
+                type_ids=type_ids,
+                dpe_ids=dpe_ids,
+                padding_mask=padding_mask,
+                encoder_only=True,
+                **kwargs,
+            )
+            features = features_ret["x"]
+            padding_mask = features_ret["padding_mask"]
+            if self.feature_grad_mult != 1.0:
+                features = GradMultiply.apply(features, self.feature_grad_mult)
+        else:
+            with torch.no_grad():
+                features_ret = super().forward(
+                    input_ids=input_ids,
+                    type_ids=type_ids,
+                    dpe_ids=dpe_ids,
+                    padding_mask=padding_mask,
+                    encoder_only=True,
+                    **kwargs,
+                )
+                features = features_ret["x"]
+                padding_mask = features_ret["padding_mask"]
+
+        features_pen = features.float().pow(2).mean()
+
+        features = self.layer_norm(features)
+        unmasked_features = features.clone()
+
+        features = self.dropout_input(features)
+        unmasked_features = self.dropout_features(unmasked_features)
+
+        num_vars = None
+        code_ppl = None
+        prob_ppl = None
+        curr_temp = None
+
+        if padding_mask is None:
+            padding_mask = input_ids[:, :, 1].eq(0).to(features.device)
+
+        x, mask_indices = self.apply_mask(features, padding_mask, mask_indices=None)
+        if mask_indices is not None:
+            y = unmasked_features[mask_indices].view(
+                unmasked_features.size(0), -1, unmasked_features.size(-1)
+            )
+        else:
+            y = unmasked_features
+
+        x = self.event_aggregator(x, src_key_padding_mask=padding_mask)
+
+        q = self.quantizer(y, produce_targets=False)
+        y = q["x"]
+        num_vars = q["num_vars"]
+        code_ppl = q["code_perplexity"]
+        prob_ppl = q["prob_perplexity"]
+        curr_temp = q["temp"]
+
+        y = self.project_q(y)
+
+        negs, _ = self.sample_negatives(y, y.size(1))
+
+        if self.codebook_negatives > 0:
+            cb_negs = self.quantizer.sample_from_codebook(y.size(0) * y.size(1), self.codebook_negatives)
+            cb_negs = cb_negs.view(self.codebook_negatives, y.size(0), y.size(1), -1)
+            cb_negs = self.project_q(cb_negs)
+            negs = torch.cat([negs, cb_negs], dim=0)
+
+        x = x[mask_indices].view(x.size(0), -1, x.size(-1))
+
+        x = self.final_proj(x)
+        x = self.compute_preds(x, y, negs)
+
+        output = {"x": x, "features_pen": features_pen, "mask_indices": mask_indices}
+        if prob_ppl is not None:
+            output["prob_perplexity"] = prob_ppl
+            output["code_perplexity"] = code_ppl
+            output["num_vars"] = num_vars
+            output["temp"] = curr_temp
+
+        return output
+
+    def get_pretraining_parameter_names(self):
+        ret = []
+        ret.append("mask_emb")
+        ret.extend(["quantizer" + "." + x[0] for x in self.quantizer.named_parameters()])
+        ret.extend(["layer_norm" + "." + x[0] for x in self.layer_norm.named_parameters()])
+        ret.extend(["project_q" + "." + x[0] for x in self.project_q.named_parameters()])
+        ret.extend(["final_proj" + "." + x[0] for x in self.final_proj.named_parameters()])
+        return ret
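
For orientation: the contrastive head above scores each masked position against its quantized target plus a set of sampled distractors, and the criterion treats index 0 along the first dimension of the result as the correct class, which is why get_targets returns all zeros. Below is a minimal standalone sketch of that scoring step, mirroring the compute_preds logic with random stand-in tensors and arbitrary shapes; it is an illustration, not part of the wheel.

import torch

# arbitrary toy shapes: B sequences, T masked steps, C-dim projections, N distractors
B, T, C, N = 2, 4, 16, 5
logit_temp = 0.1

x = torch.randn(B, T, C)             # stand-in for final_proj output at masked positions
y = torch.randn(B, T, C)             # stand-in for project_q(quantized targets)
negatives = torch.randn(N, B, T, C)  # stand-in for sample_negatives output (N x B x T x C)

# positives become row 0, negatives rows 1..N, then cosine similarity over temperature
neg_is_pos = (y == negatives).all(-1)                    # accidental duplicates of the target
targets = torch.cat([y.unsqueeze(0), negatives], dim=0)  # (N + 1) x B x T x C
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) / logit_temp
logits[1:][neg_is_pos] = float("-inf")                   # mask out duplicated negatives

print(logits.shape)  # torch.Size([6, 2, 4]): class 0 is the true target at every position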

genhpf/modules/__init__.py
@@ -0,0 +1,15 @@
+from .gather_layer import GatherLayer
+from .grad_multiply import GradMultiply
+from .gumbel_vector_quantizer import GumbelVectorQuantizer
+from .identity_layer import Identity
+from .layer_norm import LayerNorm
+from .positional_encoding import PositionalEncoding
+
+__all__ = [
+    "Identity",
+    "GatherLayer",
+    "GradMultiply",
+    "GumbelVectorQuantizer",
+    "LayerNorm",
+    "PositionalEncoding",
+]

genhpf/modules/gather_layer.py
@@ -0,0 +1,23 @@
+import torch
+import torch.distributed as dist
+
+import genhpf.utils.distributed_utils as dist_utils
+
+class GatherLayer(torch.autograd.Function):
+    """Gather tensors from all processes, supporting backward propagation."""
+
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+
+        group = dist_utils.get_data_parallel_group()
+        output = dist_utils.batch_all_gather(input, group=group)
+
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        (input, ) = ctx.saved_tensors
+        grad_out = torch.zeros_like(input)
+        grad_out[:] = grads[dist.get_rank()]
+        return grad_out
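
GatherLayer follows the usual pattern for SimCLR-style losses: forward all-gathers the per-rank embeddings, and backward routes only the local rank's gradient slice back. A hedged usage sketch follows; the helper name is hypothetical, it assumes batch_all_gather returns one tensor per rank, and it falls back to the local tensor so it also runs without an initialized process group.

import torch
import torch.distributed as dist


def gather_embeddings(z: torch.Tensor) -> torch.Tensor:
    """Concatenate embeddings from all ranks while keeping gradients to the local shard."""
    if dist.is_available() and dist.is_initialized():
        from genhpf.modules import GatherLayer
        return torch.cat(GatherLayer.apply(z), dim=0)  # tuple of per-rank tensors -> one batch
    return z  # single-process fallback


local_z = torch.randn(8, 128, requires_grad=True)
all_z = gather_embeddings(local_z)  # equals local_z here; world_size * 8 rows under DDP
print(all_z.shape)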

genhpf/modules/gumbel_vector_quantizer.py
@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class GumbelVectorQuantizer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_vars,
+        temp,
+        groups,
+        combine_groups,
+        vq_dim,
+        time_first,
+        activation=nn.GELU(),
+        weight_proj_depth=1,
+        weight_proj_factor=1,
+    ):
+        """Vector quantization using gumbel softmax
+
+        Args:
+            dim: input dimension (channels)
+            num_vars: number of quantized vectors per group
+            temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor)
+            groups: number of groups for vector quantization
+            combine_groups: whether to use the vectors for all groups
+            vq_dim: dimensionality of the resulting quantized vector
+            time_first: if true, expect input in BxTxC format, otherwise in BxCxT
+            activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1
+            weight_proj_depth: number of layers (with activation in between) to project input before computing logits
+            weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of
+                projections by this factor
+        """
+        """
+        num_vars : k
+        groups: number of codebooks (divide channels by ...)
+        combine_groups: whether to map the vectors for overall centroids (over codebooks) => false
+        """
+        super().__init__()
+
+        self.groups = groups
+        self.combine_groups = combine_groups
+        self.input_dim = dim
+        self.num_vars = num_vars
+        self.time_first = time_first
+        self.num_updates = 0
+
+        assert (
+            vq_dim % groups == 0
+        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"
+
+        var_dim = vq_dim // groups
+        num_groups = groups if not combine_groups else 1
+
+        self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim))
+        nn.init.uniform_(self.vars)
+
+        if weight_proj_depth > 1:
+
+            def block(input_dim, output_dim):
+                return nn.Sequential(nn.Linear(input_dim, output_dim), activation)
+
+            inner_dim = self.input_dim * weight_proj_factor
+            self.weight_proj = nn.Sequential(
+                *[
+                    block(self.input_dim if i == 0 else inner_dim, inner_dim)
+                    for i in range(weight_proj_depth - 1)
+                ],
+                nn.Linear(inner_dim, groups * num_vars),
+            )
+        else:
+            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
+            nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
+            nn.init.zeros_(self.weight_proj.bias)
+
+        if isinstance(temp, str):
+            import ast
+            temp = ast.literal_eval(temp)
+        assert len(temp) == 3, f"{temp}, {len(temp)}"
+
+        self.max_temp, self.min_temp, self.temp_decay = temp
+        self.curr_temp = self.max_temp
+        self.codebook_indices = None
+
+    def set_num_updates(self, num_updates):
+        self.curr_temp = max(
+            self.max_temp * self.temp_decay ** num_updates, self.min_temp
+        )
+        self.num_updates = num_updates
+
+    def get_codebook_indices(self):
+        if self.codebook_indices is None:
+            from itertools import product
+
+            p = [range(self.num_vars)] * self.groups
+            inds = list(product(*p))
+            self.codebook_indices = torch.tensor(
+                inds, dtype=torch.long, device=self.vars.device
+            ).flatten()
+
+            if not self.combine_groups:
+                self.codebook_indices = self.codebook_indices.view(
+                    self.num_vars ** self.groups, -1
+                )
+                for b in range(1, self.groups):
+                    self.codebook_indices[:, b] += self.num_vars * b
+                self.codebook_indices = self.codebook_indices.flatten()
+        return self.codebook_indices
+
+    def codebook(self):
+        indices = self.get_codebook_indices()
+        return (
+            self.vars.squeeze(0)
+            .index_select(0, indices)
+            .view(self.num_vars ** self.groups, -1)
+        )
+
+    def sample_from_codebook(self, b, n):
+        indices = self.get_codebook_indices()
+        indices = indices.view(-1, self.groups)
+        cb_size = indices.size(0)
+        assert (
+            n < cb_size
+        ), f"sample size {n} is greater than size of codebook {cb_size}"
+        sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,))
+        indices = indices[sample_idx]
+
+        z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1)
+        return z
+
+    def to_codebook_index(self, indices):
+        res = indices.new_full(indices.shape[:-1], 0)
+        for i in range(self.groups):
+            exponent = self.groups - i - 1
+            res += indices[..., i] * (self.num_vars ** exponent)
+        return res
+
+    def forward_idx(self, x):
+        res = self.forward(x, produce_targets=True)
+        return res["x"], res["targets"]
+
+    def forward(self, x, produce_targets=False):
+
+        result = {"num_vars": self.num_vars * self.groups}
+        if not self.time_first:
+            x = x.transpose(1, 2)
+
+        bsz, tsz, fsz = x.shape
+        x = x.reshape(-1, fsz)
+        x = self.weight_proj(x)
+        x = x.view(bsz * tsz * self.groups, -1)
+
+        _, k = x.max(-1)
+
+        hard_x = (
+            x.new_zeros(*x.shape)
+            .scatter_(-1, k.view(-1, 1), 1.0)
+            .view(bsz * tsz, self.groups, -1)
+        )
+        hard_probs = torch.mean(hard_x.float(), dim=0)
+        result["code_perplexity"] = torch.exp(
+            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
+        ).sum()
+
+        avg_probs = torch.softmax(
+            x.view(bsz * tsz, self.groups, -1).float(), dim=-1
+        ).mean(dim=0)
+        result["prob_perplexity"] = torch.exp(
+            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
+        ).sum()
+
+        result["temp"] = self.curr_temp
+
+        if self.training:
+            x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x)
+        else:
+            x = hard_x
+
+        x = x.view(bsz * tsz, -1)
+
+        vars = self.vars
+
+        if self.combine_groups:
+            vars = vars.repeat(1, self.groups, 1)
+
+        if produce_targets:
+            result["targets"] = (
+                x.view(bsz * tsz * self.groups, -1)
+                .argmax(dim=-1)
+                .view(bsz, tsz, self.groups)
+                .detach()
+            )
+
+        x = x.unsqueeze(-1) * vars
+        x = x.view(bsz * tsz, self.groups, self.num_vars, -1)
+        x = x.sum(-2)
+        x = x.view(bsz, tsz, -1)
+
+        if not self.time_first:
+            x = x.transpose(1, 2) # BTC -> BCT
+
+        result["x"] = x
+
+        return result
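
For reference, a small usage sketch of the quantizer as GenHPFWav2Vec2 configures it above. It assumes the genhpf wheel and its dependencies are installed; the embedding width and the batch/length values are arbitrary placeholders, not package defaults.

import torch

from genhpf.modules import GumbelVectorQuantizer

embed_dim = 128  # placeholder for cfg.agg_embed_dim
quantizer = GumbelVectorQuantizer(
    dim=embed_dim,
    num_vars=320,
    temp=(2, 0.5, 0.999995),
    groups=2,
    combine_groups=False,
    vq_dim=embed_dim,
    time_first=True,  # inputs are B x T x C
)

y = torch.randn(4, 10, embed_dim)
out = quantizer(y, produce_targets=True)

print(out["x"].shape)        # torch.Size([4, 10, 128]): quantized vector per timestep
print(out["targets"].shape)  # torch.Size([4, 10, 2]): selected codebook entry per group
print(out["num_vars"])       # 640 = num_vars * groups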

genhpf/modules/layer_norm.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+has_fused_layernorm = False
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+    if torch.jit.is_scripting():
+        export = True
+    if not export and torch.cuda.is_available() and has_fused_layernorm:
+        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
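
Note that has_fused_layernorm is hard-coded to False, so the LayerNorm factory always returns the stock torch.nn.LayerNorm and the FusedLayerNorm branch is unreachable. Fp32LayerNorm differs only in running the normalization in float32 and casting back to the input dtype. A quick sketch, assuming the wheel is installed; the width and tensor shapes are arbitrary:

import torch

from genhpf.modules import LayerNorm
from genhpf.modules.layer_norm import Fp32LayerNorm

ln = LayerNorm(128)
print(type(ln))  # <class 'torch.nn.modules.normalization.LayerNorm'>

fp32_ln = Fp32LayerNorm(128)
x_half = torch.randn(2, 5, 128).half()
out = fp32_ln(x_half)
print(out.dtype)  # torch.float16: normalized in fp32, cast back to the input dtype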

genhpf/modules/positional_encoding.py
@@ -0,0 +1,24 @@
+import math
+
+import torch
+import torch.nn as nn
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, dropout, max_len):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(1, max_len, d_model)
+        pe[0, :, 0::2] = torch.sin(position * div_term)
+        pe[0, :, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor, shape [batch_size, seq_len, embedding_dim]
+        """
+
+        x = x + self.pe[:, :x.size(1)]
+        return self.dropout(x)
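
And a corresponding sketch for the sinusoidal positional encoding, again assuming the wheel is installed; the model width, sequence length, and max_len below are arbitrary illustration values:

import torch

from genhpf.modules import PositionalEncoding

pos_enc = PositionalEncoding(d_model=128, dropout=0.1, max_len=512)

x = torch.randn(2, 50, 128)  # B x T x C, as the forward docstring expects
out = pos_enc(x)             # adds pe[:, :50] and applies dropout
print(out.shape)             # torch.Size([2, 50, 128])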

File without changes
File without changes