genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
@@ -0,0 +1,304 @@
1
+ import logging
2
+ from dataclasses import dataclass, field
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from genhpf.models import register_model
9
+ from genhpf.models.genhpf import GenHPF, GenHPFConfig
10
+ from genhpf.modules import GradMultiply, GumbelVectorQuantizer, LayerNorm
11
+ from genhpf.utils import utils
12
+ from genhpf.utils.data_utils import compute_mask_indices
13
+
14
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
15
+
16
+
17
@dataclass
class GenHPFWav2Vec2Config(GenHPFConfig):
    """Configuration for the ``genhpf_wav2vec2`` pretraining model.

    Extends ``GenHPFConfig`` with the quantizer, negative-sampling, masking and
    regularisation options read by ``GenHPFWav2Vec2``. Field order is part of the
    dataclass constructor signature and must not be changed.
    """

    # contrastive objective / Gumbel quantizer
    logit_temp: float = field(default=0.1, metadata={"help": "temperature to divide logits by"})
    latent_vars: int = field(
        default=320, metadata={"help": "number of latent variables V in each group of the codebook"}
    )
    latent_groups: int = field(
        default=2, metadata={"help": "number of groups G of latent variables in the codebook"}
    )
    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={
            "help": "temperature for latent variable sampling. "
            "can be tuple of 3 values (start, end, decay)"
        },
    )
    final_dim: int = field(
        default=128, metadata={"help": "project final representations and targets to this many dimensions."}
    )

    # negatives
    num_negatives: int = field(
        default=25, metadata={"help": "number of negative examples from the same sample"}
    )
    codebook_negatives: int = field(
        default=0, metadata={"help": "number of negative examples sampled from the codebook"}
    )

    # mask
    mask_prob: float = field(default=0.65, metadata={"help": "probability of replacing a token with mask"})
    mask_length: int = field(default=1, metadata={"help": "mask length"})
    # The flag's name says "no overlap": setting it True forbids overlapping spans.
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "if set, masked spans are not allowed to overlap"}
    )
    mask_min_space: int = field(
        default=0, metadata={"help": "min space between spans (if no overlap is enabled)"}
    )

    # regularisation
    feature_grad_mult: float = field(
        default=0.1, metadata={"help": "multiply event encoder gradients by this"}
    )
    dropout_input: float = field(default=0.1, metadata={"help": "dropout to apply to the input"})
    dropout_features: float = field(default=0.1, metadata={"help": "dropout to apply to the features"})
56
+
57
@register_model("genhpf_wav2vec2", dataclass=GenHPFWav2Vec2Config)
class GenHPFWav2Vec2(GenHPF):
    """Contrastive (wav2vec 2.0-style) pretraining model on top of the GenHPF encoder.

    ``forward`` does the following:

    1. Runs the parent ``GenHPF`` forward with ``encoder_only=True`` to get event features.
    2. Replaces a random subset of time steps with a learned mask embedding.
    3. Runs the event aggregator.
    4. Scores each masked position against its Gumbel-quantized, projected target
       and against sampled distractors (cosine similarity / ``logit_temp``).
    """

    def __init__(self, cfg: GenHPFWav2Vec2Config):
        super().__init__(cfg)

        self.logit_temp = cfg.logit_temp

        self.mask_prob = cfg.mask_prob
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.feature_grad_mult = cfg.feature_grad_mult

        self.codebook_negatives = cfg.codebook_negatives

        # Module construction order is kept as-is so parameter initialisation
        # consumes the RNG in the same order as before.
        self.quantizer = GumbelVectorQuantizer(
            dim=cfg.agg_embed_dim,
            num_vars=cfg.latent_vars,
            temp=cfg.latent_temp,
            groups=cfg.latent_groups,
            combine_groups=False,
            vq_dim=cfg.agg_embed_dim,
            time_first=True,
        )

        self.layer_norm = LayerNorm(cfg.agg_embed_dim)
        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)
        self.project_q = nn.Linear(cfg.agg_embed_dim, cfg.final_dim)
        self.final_proj = nn.Linear(cfg.agg_embed_dim, cfg.final_dim)
        self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.agg_embed_dim).uniform_())
        # Cross-sample negatives are not exposed in the config; only same-sample
        # (num_negatives) and codebook negatives are drawn.
        self.cross_sample_negatives = 0
        self.n_negatives = cfg.num_negatives

    @classmethod
    def build_model(cls, cfg):
        """Build a new model instance."""
        return cls(cfg)

    def compute_preds(self, x, y, negatives):
        """Return temperature-scaled cosine-similarity logits.

        ``y`` (the positive target) is stacked in front of ``negatives`` along dim 0,
        so row 0 of the result is the positive. Negatives identical to the positive
        get a logit of ``-inf`` so they cannot be picked.
        NOTE(review): assumes ``negatives`` is N x B x T x C as produced by
        ``sample_negatives`` (plus optional codebook negatives).
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)

        logits = logits / self.logit_temp
        logits = logits.type_as(x)

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def apply_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor,
        mask_indices=None,
    ):
        """Overwrite masked time steps of ``x`` (in place) with ``mask_emb``.

        Returns ``(x, mask_indices)``. ``mask_indices`` is None when ``mask_prob <= 0``.
        Otherwise it is the (B, T) mask from ``compute_mask_indices``, moved to
        ``x.device``.
        """
        bsz, tsz, csz = x.shape

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    shape=(bsz, tsz),
                    padding_mask=padding_mask,
                    mask_prob=self.mask_prob,
                    mask_length=self.mask_length,
                    mask_type="static",
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        return x, mask_indices

    def sample_negatives(self, y, num):
        """Sample distractor vectors for every one of ``num`` target positions.

        Same-sample negatives are drawn uniformly from the other time steps of the
        same sequence; the position itself is excluded by shifting indices ``>= t`` up
        by one.

        Returns ``(negs, neg_idxs)``, where ``negs`` is N x B x num x C. When no
        negatives are configured, returns ``(y.new(0), None)``.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # Return a 2-tuple like the normal path; the caller unpacks two values.
            return y.new(0), None

        batch_size, time_size, feature_size = y.shape
        y = y.view(-1, feature_size)  # B x T x C -> (B x T) x C

        cross_high = time_size * batch_size
        high = time_size
        with torch.no_grad():
            assert high > 1, f"{batch_size, time_size, feature_size}"

            if self.n_negatives > 0:
                time_sizes = utils.buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten()
                neg_idxs = torch.randint(low=0, high=high - 1, size=(batch_size, self.n_negatives * num))
                neg_idxs[neg_idxs >= time_sizes] += 1

            if self.cross_sample_negatives > 0:
                time_sizes = (
                    utils.buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0, high=cross_high - 1, size=(batch_size, self.cross_sample_negatives * num)
                )
                cross_neg_idxs[cross_neg_idxs >= time_sizes] += 1

        if self.n_negatives > 0:
            # Offset per-sequence indices into the flattened (B x T) row space.
            for i in range(1, batch_size):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            batch_size, num, self.n_negatives + self.cross_sample_negatives, feature_size
        ).permute(
            2, 0, 1, 3
        )  # to N x B x T x C

        return negs, neg_idxs

    def get_logits(self, sample, net_output):
        """Flatten ``net_output["x"]`` into a 2-D (rows, 1 + num_negatives) logit matrix."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))

        return logits

    def get_targets(self, sample, net_output):
        """Return all-zero class targets, since the positive is always row 0 of the logits."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect auxiliary penalties found in ``net_output``.

        These are the codebook diversity term ``(num_vars - prob_perplexity) / num_vars``
        and the feature L2 penalty ``features_pen``.
        """
        pen = []

        if "prob_perplexity" in net_output:
            pen.append((net_output["num_vars"] - net_output["prob_perplexity"]) / net_output["num_vars"])

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def forward(
        self,
        input_ids: torch.Tensor,
        type_ids: torch.Tensor = None,
        dpe_ids: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Run one contrastive pretraining step.

        Returns a dict with the following keys:

        - ``x``: logits from ``compute_preds``.
        - ``features_pen``: mean squared encoder feature.
        - ``mask_indices``: the mask used.
        - ``prob_perplexity``, ``code_perplexity``, ``num_vars``, ``temp``: quantizer
          statistics.
        """
        if self.feature_grad_mult > 0:
            features_ret = super().forward(
                input_ids=input_ids,
                type_ids=type_ids,
                dpe_ids=dpe_ids,
                padding_mask=padding_mask,
                encoder_only=True,
                **kwargs,
            )
            features = features_ret["x"]
            padding_mask = features_ret["padding_mask"]
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # Frozen event encoder: no gradient flows back into it.
            with torch.no_grad():
                # Bind to ``features_ret``; the previous code assigned to ``features``
                # and then read the undefined name ``features_ret`` (NameError).
                features_ret = super().forward(
                    input_ids=input_ids,
                    type_ids=type_ids,
                    dpe_ids=dpe_ids,
                    padding_mask=padding_mask,
                    encoder_only=True,
                    **kwargs,
                )
            features = features_ret["x"]
            padding_mask = features_ret["padding_mask"]

        features_pen = features.float().pow(2).mean()

        features = self.layer_norm(features)
        unmasked_features = features.clone()

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if padding_mask is None:
            padding_mask = input_ids[:, :, 1].eq(0).to(features.device)

        x, mask_indices = self.apply_mask(features, padding_mask, mask_indices=None)
        if mask_indices is not None:
            y = unmasked_features[mask_indices].view(
                unmasked_features.size(0), -1, unmasked_features.size(-1)
            )
        else:
            y = unmasked_features

        x = self.event_aggregator(x, src_key_padding_mask=padding_mask)

        q = self.quantizer(y, produce_targets=False)
        y = q["x"]
        num_vars = q["num_vars"]
        code_ppl = q["code_perplexity"]
        prob_ppl = q["prob_perplexity"]
        curr_temp = q["temp"]

        y = self.project_q(y)

        negs, _ = self.sample_negatives(y, y.size(1))

        if self.codebook_negatives > 0:
            cb_negs = self.quantizer.sample_from_codebook(y.size(0) * y.size(1), self.codebook_negatives)
            cb_negs = cb_negs.view(self.codebook_negatives, y.size(0), y.size(1), -1)
            cb_negs = self.project_q(cb_negs)
            negs = torch.cat([negs, cb_negs], dim=0)

        # Without masking (mask_prob <= 0) every position is a prediction target,
        # matching ``y = unmasked_features`` above; indexing with None would add a dim.
        if mask_indices is not None:
            x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        output = {"x": x, "features_pen": features_pen, "mask_indices": mask_indices}
        if prob_ppl is not None:
            output["prob_perplexity"] = prob_ppl
            output["code_perplexity"] = code_ppl
            output["num_vars"] = num_vars
            output["temp"] = curr_temp

        return output

    def get_pretraining_parameter_names(self):
        """Return names of parameters used only by the pretraining head.

        This is ``mask_emb`` plus every parameter of ``quantizer``, ``layer_norm``,
        ``project_q`` and ``final_proj``, in that order.
        """
        names = ["mask_emb"]
        for prefix in ("quantizer", "layer_norm", "project_q", "final_proj"):
            module = getattr(self, prefix)
            names.extend(f"{prefix}.{param_name}" for param_name, _ in module.named_parameters())
        return names
@@ -0,0 +1,15 @@
1
+ from .gather_layer import GatherLayer
2
+ from .grad_multiply import GradMultiply
3
+ from .gumbel_vector_quantizer import GumbelVectorQuantizer
4
+ from .identity_layer import Identity
5
+ from .layer_norm import LayerNorm
6
+ from .positional_encoding import PositionalEncoding
7
+
8
# Names re-exported as the public API of this package.
__all__ = [
    "Identity", "GatherLayer", "GradMultiply",
    "GumbelVectorQuantizer", "LayerNorm", "PositionalEncoding",
]
@@ -0,0 +1,23 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ import genhpf.utils.distributed_utils as dist_utils
5
+
6
class GatherLayer(torch.autograd.Function):
    """All-gather a tensor across the data-parallel group while keeping it differentiable.

    ``forward`` returns, as a tuple, whatever ``dist_utils.batch_all_gather`` produces
    for the data-parallel group. ``backward`` hands back only the gradient slot indexed
    by this process's global rank.

    NOTE(review): the gather uses the data-parallel group but the gradient is selected
    with the *global* rank; these agree only when that group spans every rank. Also,
    gradients from other ranks are not summed, which is correct only if every rank
    computes the loss over the full gathered batch. TODO confirm against the criterion.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        gathered = dist_utils.batch_all_gather(
            input, group=dist_utils.get_data_parallel_group()
        )
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        (saved_input,) = ctx.saved_tensors
        local_grad = torch.zeros_like(saved_input)
        local_grad[:] = grads[dist.get_rank()]
        return local_grad
@@ -0,0 +1,12 @@
1
+ import torch
2
+
3
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``scale`` in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        scaled = grad * ctx.scale
        # ``scale`` is a plain number, so it receives no gradient.
        return scaled, None
@@ -0,0 +1,204 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class GumbelVectorQuantizer(nn.Module):
    def __init__(
        self,
        dim,
        num_vars,
        temp,
        groups,
        combine_groups,
        vq_dim,
        time_first,
        activation=None,
        weight_proj_depth=1,
        weight_proj_factor=1,
    ):
        """Vector quantization using Gumbel softmax.

        Args:
            dim: input dimension (channels).
            num_vars: number of quantized vectors (codewords) per group.
            temp: temperature for training as a 3-tuple ``(start, stop, decay factor)``.
                A string literal such as ``"(2, 0.5, 0.999995)"`` is also accepted and
                parsed with ``ast.literal_eval``.
            groups: number of codebook groups. The output is the concatenation of one
                codeword per group.
            combine_groups: if true, all groups share one set of ``num_vars`` codewords.
            vq_dim: dimensionality of the resulting quantized vector. Must be divisible
                by ``groups``.
            time_first: if true, expect input in BxTxC format, otherwise in BxCxT.
            activation: activation module placed between projection layers. Only used
                if ``weight_proj_depth > 1``. Defaults to a new ``nn.GELU()`` per instance.
            weight_proj_depth: number of layers (with activation in between) used to
                project the input before computing logits.
            weight_proj_factor: scales the inner dimensionality of the projections.
                Only used if ``weight_proj_depth > 1``.
        """
        super().__init__()

        # Build the default here rather than in the signature so instances do not
        # share one module object created at definition time.
        if activation is None:
            activation = nn.GELU()

        self.groups = groups
        self.combine_groups = combine_groups
        self.input_dim = dim
        self.num_vars = num_vars
        self.time_first = time_first
        self.num_updates = 0

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        var_dim = vq_dim // groups
        num_groups = groups if not combine_groups else 1

        self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim))
        nn.init.uniform_(self.vars)

        if weight_proj_depth > 1:

            def block(input_dim, output_dim):
                return nn.Sequential(nn.Linear(input_dim, output_dim), activation)

            inner_dim = self.input_dim * weight_proj_factor
            self.weight_proj = nn.Sequential(
                *[
                    block(self.input_dim if i == 0 else inner_dim, inner_dim)
                    for i in range(weight_proj_depth - 1)
                ],
                nn.Linear(inner_dim, groups * num_vars),
            )
        else:
            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
            nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
            nn.init.zeros_(self.weight_proj.bias)

        if isinstance(temp, str):
            import ast

            temp = ast.literal_eval(temp)
        assert len(temp) == 3, f"{temp}, {len(temp)}"

        self.max_temp, self.min_temp, self.temp_decay = temp
        self.curr_temp = self.max_temp
        self.codebook_indices = None

    def set_num_updates(self, num_updates):
        """Decay the Gumbel temperature: ``max(max_temp * decay**num_updates, min_temp)``."""
        self.curr_temp = max(
            self.max_temp * self.temp_decay ** num_updates, self.min_temp
        )
        self.num_updates = num_updates

    def get_codebook_indices(self):
        """Return (and cache) flat row indices into ``vars`` for every codeword combination.

        The cached tensor is moved to ``self.vars``'s device if the module has been
        moved since it was built. Otherwise ``index_select`` would get a device mismatch.
        """
        if self.codebook_indices is None:
            from itertools import product

            p = [range(self.num_vars)] * self.groups
            inds = list(product(*p))
            self.codebook_indices = torch.tensor(
                inds, dtype=torch.long, device=self.vars.device
            ).flatten()

            if not self.combine_groups:
                self.codebook_indices = self.codebook_indices.view(
                    self.num_vars ** self.groups, -1
                )
                # Group b's codewords live at rows [b * num_vars, (b + 1) * num_vars).
                for b in range(1, self.groups):
                    self.codebook_indices[:, b] += self.num_vars * b
                self.codebook_indices = self.codebook_indices.flatten()
        elif self.codebook_indices.device != self.vars.device:
            self.codebook_indices = self.codebook_indices.to(self.vars.device)
        return self.codebook_indices

    def codebook(self):
        """Return every combined codeword as a ``(num_vars ** groups, -1)`` matrix."""
        indices = self.get_codebook_indices()
        return (
            self.vars.squeeze(0)
            .index_select(0, indices)
            .view(self.num_vars ** self.groups, -1)
        )

    def sample_from_codebook(self, b, n):
        """Sample ``b * n`` random combined codewords, returned with shape ``(b, n, -1)``."""
        indices = self.get_codebook_indices()
        indices = indices.view(-1, self.groups)
        cb_size = indices.size(0)
        assert (
            n < cb_size
        ), f"sample size {n} is greater than size of codebook {cb_size}"
        sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,))
        indices = indices[sample_idx]

        z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1)
        return z

    def to_codebook_index(self, indices):
        """Convert per-group indices (last dim = groups) to one combined index.

        The first group is the most significant digit in base ``num_vars``.
        """
        res = indices.new_full(indices.shape[:-1], 0)
        for i in range(self.groups):
            exponent = self.groups - i - 1
            res += indices[..., i] * (self.num_vars ** exponent)
        return res

    def forward_idx(self, x):
        """Return ``(quantized x, per-group target indices)``."""
        res = self.forward(x, produce_targets=True)
        return res["x"], res["targets"]

    def forward(self, x, produce_targets=False):
        """Quantize ``x``.

        In training mode codewords are chosen with hard Gumbel-softmax; in eval mode
        with argmax.

        Returns a dict with ``x`` (same layout as the input, last/channel dim
        ``vq_dim``), ``num_vars``, ``code_perplexity``, ``prob_perplexity`` and ``temp``.
        It also contains ``targets`` (B x T x groups) when ``produce_targets`` is true.
        """
        result = {"num_vars": self.num_vars * self.groups}
        if not self.time_first:
            x = x.transpose(1, 2)

        bsz, tsz, fsz = x.shape
        x = x.reshape(-1, fsz)
        x = self.weight_proj(x)
        x = x.view(bsz * tsz * self.groups, -1)

        _, k = x.max(-1)

        hard_x = (
            x.new_zeros(*x.shape)
            .scatter_(-1, k.view(-1, 1), 1.0)
            .view(bsz * tsz, self.groups, -1)
        )
        hard_probs = torch.mean(hard_x.float(), dim=0)
        result["code_perplexity"] = torch.exp(
            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
        ).sum()

        avg_probs = torch.softmax(
            x.view(bsz * tsz, self.groups, -1).float(), dim=-1
        ).mean(dim=0)
        result["prob_perplexity"] = torch.exp(
            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
        ).sum()

        result["temp"] = self.curr_temp

        if self.training:
            x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x)
        else:
            x = hard_x

        x = x.view(bsz * tsz, -1)

        codebook_vars = self.vars

        if self.combine_groups:
            codebook_vars = codebook_vars.repeat(1, self.groups, 1)

        if produce_targets:
            result["targets"] = (
                x.view(bsz * tsz * self.groups, -1)
                .argmax(dim=-1)
                .view(bsz, tsz, self.groups)
                .detach()
            )

        # One-hot selection times codewords, summed over the num_vars axis, picks one
        # codeword per group.
        x = x.unsqueeze(-1) * codebook_vars
        x = x.view(bsz * tsz, self.groups, self.num_vars, -1)
        x = x.sum(-2)
        x = x.view(bsz, tsz, -1)

        if not self.time_first:
            x = x.transpose(1, 2)  # BTC -> BCT

        result["x"] = x

        return result
@@ -0,0 +1,8 @@
1
+ import torch.nn as nn
2
+
3
class Identity(nn.Module):
    """No-op module whose ``forward`` returns its argument unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
# No fused layer-norm kernel is imported in this module, so the flag stays off.
has_fused_layernorm = bool(0)
6
+
7
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a layer-normalisation module.

    Args:
        normalized_shape: forwarded to ``torch.nn.LayerNorm``.
        eps: forwarded to ``torch.nn.LayerNorm``.
        elementwise_affine: forwarded to ``torch.nn.LayerNorm``.
        export: accepted for signature compatibility only. It has no effect because
            no fused implementation is available in this module.

    Returns:
        A ``torch.nn.LayerNorm`` instance.
    """
    # A fused-kernel branch used to be selectable here, but it referenced
    # ``FusedLayerNorm``, which this module never defines or imports. Enabling it
    # could only raise NameError, so the plain PyTorch module is always returned.
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
13
+
14
+
15
class Fp32LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always normalises in float32 and casts the result back
    to the input's dtype. Constructor arguments are those of ``nn.LayerNorm``."""

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normalized.type_as(input)
@@ -0,0 +1,24 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even channels hold ``sin(pos * w_i)`` and odd channels hold ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: embedding dimension. Odd values are supported; the extra last
            channel then carries a sine term only.
        dropout: dropout probability applied after adding the encodings.
        max_len: maximum sequence length the precomputed table covers.
    """

    def __init__(self, d_model, dropout, max_len):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd channels, which is one fewer than len(div_term)
        # when d_model is odd. Slice so the assignment shapes match.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim] with seq_len <= max_len.

        Returns:
            ``dropout(x + pe[:, :seq_len])``, same shape as ``x``.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
File without changes
File without changes