egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1194 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, einsum
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
import torch.distributed as distributed
|
|
7
|
+
from torch.optim import Optimizer
|
|
8
|
+
from torch.cuda.amp import autocast
|
|
9
|
+
|
|
10
|
+
from einops import rearrange, repeat, reduce, pack, unpack
|
|
11
|
+
|
|
12
|
+
from typing import Callable
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def exists(val):
    """True when *val* is not ``None``."""
    return not (val is None)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def default(val, d):
    """Return *val* unless it is ``None``, in which case return fallback *d*."""
    if val is None:
        return d
    return val
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def noop(*args, **kwargs):
    """Do nothing; placeholder for optional callbacks (e.g. all-reduce hooks)."""
    return None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def identity(t):
    """Identity transform; used as a no-op input normalizer for codebooks."""
    return t
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, dim=-1, p=2)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def cdist(x, y):
    """Batched pairwise Euclidean distance.

    Args:
        x: tensor of shape ``(b, i, d)``.
        y: tensor of shape ``(b, j, d)``.

    Returns:
        Tensor of shape ``(b, i, j)`` with ``out[b, i, j] = ||x[b, i] - y[b, j]||_2``.

    Uses the expansion ``|x - y|^2 = |x|^2 + |y|^2 - 2 x.y``. Floating-point
    cancellation can make the expanded squared distance slightly negative, and
    ``sqrt`` of a negative produces NaN — so clamp at zero first (bug fix).
    """
    x2 = (x ** 2).sum(dim=-1)
    y2 = (y ** 2).sum(dim=-1)
    xy = einsum("b i d, b j d -> b i j", x, y) * -2
    sq = x2.unsqueeze(-1) + y2.unsqueeze(-2) + xy
    # clamp: guard against tiny negative values from cancellation before sqrt
    return sq.clamp(min=0).sqrt()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def log(t, eps=1e-20):
    """Numerically safe log: clamp *t* below by *eps* so ``log(0)`` never hits -inf."""
    clamped = t.clamp(min=eps)
    return torch.log(clamped)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def ema_inplace(old, new, decay):
    """In-place EMA update: ``old <- decay * old + (1 - decay) * new``."""
    if str(old.device).startswith("mps:"):
        # Apple MPS backend: emulate lerp_ with mul_/add_
        old.mul_(decay).add_(new * (1 - decay))
    else:
        old.lerp_(new, 1 - decay)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def pack_one(t, pattern):
    """``einops.pack`` specialized to a single tensor; returns ``(packed, pack_shape)``."""
    packed, shapes = pack([t], pattern)
    return packed, shapes
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def unpack_one(t, ps, pattern):
    """Inverse of :func:`pack_one`: unpack *t* and return the single tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def uniform_init(*shape):
    """Return a tensor of *shape* filled with Kaiming-uniform values."""
    # kaiming_uniform_ mutates in place and returns the same tensor
    return nn.init.kaiming_uniform_(torch.empty(shape))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def gumbel_noise(t):
    """Sample standard Gumbel(0, 1) noise with the same shape as *t*."""
    u = torch.zeros_like(t).uniform_(0, 1)
    # -log(-log(u)), with both logs clamped at 1e-20 to stay finite at u -> 0 or 1
    inner = torch.log(u.clamp(min=1e-20))
    return -torch.log((-inner).clamp(min=1e-20))
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def gumbel_sample(
    logits,
    temperature=1.0,
    stochastic=False,
    straight_through=False,
    reinmax=False,
    dim=-1,
    training=True,
):
    """Pick codebook indices from *logits*, optionally via Gumbel sampling.

    Args:
        logits: unnormalized scores; the class axis is *dim*.
        temperature: softmax/Gumbel temperature; <= 0 disables sampling.
        stochastic: add Gumbel noise before the argmax (training only).
        straight_through: make the returned one-hot differentiable via the
            straight-through estimator.
        reinmax: use ReinMax (arXiv:2304.08612, algorithm 2) for a
            second-order-accurate straight-through gradient.
        dim: class dimension of *logits*.
        training: evaluation mode short-circuits to a plain argmax.

    Returns:
        ``(indices, one_hot)`` — the sampled indices and their (possibly
        gradient-carrying) one-hot encoding.
    """
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(dim=dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (
        reinmax and not straight_through
    ), "reinmax can only be turned on if using straight through gumbel softmax"

    # no gradient trickery needed outside straight-through training
    if not straight_through or temperature <= 0.0 or not training:
        return ind, one_hot

    if reinmax:
        # ReinMax, algorithm 2 — https://arxiv.org/abs/2304.08612
        pi0 = logits.softmax(dim=dim)
        pi1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2
        # BUG FIX: softmax must run over the class axis `dim`; the previous
        # hard-coded dim=1 was wrong for the 3-d (h, n, c) logits used here.
        pi1 = ((log(pi1) - logits).detach() + logits).softmax(dim=dim)
        pi2 = 2 * pi1 - 0.5 * pi0
        one_hot = pi2 - pi2.detach() + one_hot
    else:
        # classic straight-through: forward = one_hot, backward = softmax grad
        pi1 = (logits / temperature).softmax(dim=dim)
        one_hot = one_hot + pi1 - pi1.detach()

    return ind, one_hot
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1):
    """Additively smooth counts so no category ends up with zero mass."""
    total = x.sum(dim=dim, keepdim=True)
    smoothed = x + eps
    return smoothed / (total + n_categories * eps)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def sample_vectors(samples, num):
    """Pick *num* rows from *samples* — without replacement when enough rows
    exist, otherwise with replacement."""
    n, device = samples.shape[0], samples.device
    if n < num:
        idx = torch.randint(0, n, (num,), device=device)
    else:
        idx = torch.randperm(n, device=device)[:num]
    return samples[idx]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def batched_sample_vectors(samples, num):
    """Apply :func:`sample_vectors` independently to every codebook (dim 0)."""
    picked = [sample_vectors(batch, num) for batch in samples.unbind(dim=0)]
    return torch.stack(picked, dim=0)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def pad_shape(shape, size, dim=0):
    """Copy of *shape* (as a list) with entry *dim* replaced by *size*.

    Out-of-range *dim* leaves the shape unchanged, matching the original
    enumerate-based comparison semantics.
    """
    out = list(shape)
    if 0 <= dim < len(out):
        out[dim] = size
    return out
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def sample_multinomial(total_count, probs):
    """Draw one multinomial sample via sequential binomial draws.

    torch.binomial is CPU-friendly here; the draw happens on CPU and the
    result is moved back to *probs*' original device. The counts always sum
    to *total_count*.
    """
    device = probs.device
    probs = probs.cpu()

    remaining = probs.new_full((), total_count)
    mass_left = probs.new_ones(())
    out = torch.empty_like(probs, dtype=torch.long)

    for idx, p in enumerate(probs):
        # conditional binomial: of what is left, how many land in bucket idx
        drawn = torch.binomial(remaining, p / mass_left)
        out[idx] = drawn
        remaining -= drawn
        mass_left -= p

    return out.to(device)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def all_gather_sizes(x, dim):
    """Gather ``x.shape[dim]`` from every rank and stack into a 1-d tensor."""
    local = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device)
    world = distributed.get_world_size()
    gathered = [torch.empty_like(local) for _ in range(world)]
    distributed.all_gather(gathered, local)
    return torch.stack(gathered)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def all_gather_variably_sized(x, sizes, dim=0):
    """Gather tensors of differing sizes along *dim* from all ranks.

    Each rank broadcasts its own tensor; receivers allocate a buffer of the
    advertised size. A final barrier waits for the async broadcasts.
    """
    rank = distributed.get_rank()
    gathered = []

    for src, size in enumerate(sizes):
        if src == rank:
            t = x
        else:
            t = x.new_empty(pad_shape(x.shape, size, dim))
        distributed.broadcast(t, src=src, async_op=True)
        gathered.append(t)

    distributed.barrier()
    return gathered
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def sample_vectors_distributed(local_samples, num):
    """Sample *num* vectors across all DDP ranks, each rank contributing a
    share proportional to how many samples it holds.

    NOTE(review): the collective-call order (all_gather -> broadcast ->
    variably-sized gather) must be identical on every rank; code kept
    byte-identical for that reason.
    """
    # drop the leading singleton codebook dim (asserts it is 1)
    local_samples = rearrange(local_samples, "1 ... -> ...")

    rank = distributed.get_rank()
    all_num_samples = all_gather_sizes(local_samples, dim=0)

    if rank == 0:
        # rank 0 decides the per-rank allocation via one multinomial draw
        samples_per_rank = sample_multinomial(
            num, all_num_samples / all_num_samples.sum()
        )
    else:
        samples_per_rank = torch.empty_like(all_num_samples)

    # share the allocation so every rank samples the agreed count
    distributed.broadcast(samples_per_rank, src=0)
    samples_per_rank = samples_per_rank.tolist()

    local_samples = sample_vectors(local_samples, samples_per_rank[rank])
    all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0)
    out = torch.cat(all_samples, dim=0)

    # restore the leading codebook dim
    return rearrange(out, "... -> 1 ...")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def batched_bincount(x, *, minlength):
    """Row-wise bincount: for index tensor *x* of shape ``(batch, n)``,
    return per-row counts of shape ``(batch, minlength)``."""
    # new_zeros inherits x's dtype and device, matching the original behavior
    counts = x.new_zeros(x.shape[0], minlength)
    counts.scatter_add_(-1, x, torch.ones_like(x))
    return counts
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def kmeans(
    samples,
    num_clusters,
    num_iters=10,
    use_cosine_sim=False,
    sample_fn=batched_sample_vectors,
    all_reduce_fn=noop,
):
    """Batched (per-codebook) k-means.

    Args:
        samples: tensor of shape ``(num_codebooks, n, dim)``.
        num_clusters: number of centroids per codebook.
        num_iters: Lloyd iterations to run.
        use_cosine_sim: assign by dot product (inputs presumably already
            l2-normalized by the caller — confirm) instead of Euclidean distance.
        sample_fn: initial-centroid sampler (DDP-aware variant when synced).
        all_reduce_fn: reduction hook for distributed training (noop locally).

    Returns:
        ``(means, bins)`` — centroids ``(num_codebooks, num_clusters, dim)``
        and the final per-cluster counts.
    """
    # NOTE(review): `device` is unpacked but unused; kept for byte-identity.
    num_codebooks, dim, dtype, device = (
        samples.shape[0],
        samples.shape[-1],
        samples.dtype,
        samples.device,
    )

    # initialize centroids by sampling from the data
    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        # assignment step: higher "dists" = better match
        if use_cosine_sim:
            dists = samples @ rearrange(means, "h n d -> h d n")
        else:
            dists = -torch.cdist(samples, means, p=2)

        buckets = torch.argmax(dists, dim=-1)
        bins = batched_bincount(buckets, minlength=num_clusters)
        all_reduce_fn(bins)

        # clusters that captured nothing keep their old centroid below
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)

        # update step: sum members per cluster, then divide by (clamped) counts
        new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
        new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
        all_reduce_fn(new_means)

        if use_cosine_sim:
            new_means = l2norm(new_means)

        # empty clusters retain their previous centroid
        means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)

    return means, bins
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def batched_embedding(indices, embeds):
    """Per-codebook lookup: gather *indices* ``(h, b, n)`` from *embeds*
    ``(h, c, d)``, producing ``(h, b, n, d)``."""
    b, d = indices.shape[1], embeds.shape[-1]
    gather_idx = repeat(indices, "h b n -> h b n d", d=d)
    expanded = repeat(embeds, "h c d -> h b c d", b=b)
    return expanded.gather(2, gather_idx)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# regularization losses
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def orthogonal_loss_fn(t):
    """Orthogonality regularizer for codebooks — eq. (2) of
    https://arxiv.org/abs/2112.00384. Zero when rows are orthonormal."""
    h, n = t.shape[0], t.shape[1]
    normed = F.normalize(t, p=2, dim=-1)
    gram = einsum("h i d, h j d -> h i j", normed, normed)
    return (gram ** 2).sum() / (h * n ** 2) - (1 / n)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
# distance types
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean-distance assignment and EMA updates.

    Holds ``num_codebooks`` independent codebooks of ``codebook_size`` entries
    of dimension ``dim``. Codes are assigned by (negative) Euclidean distance;
    during training the embeddings are maintained either by EMA (default) or as
    learnable parameters. Supports lazy kmeans initialization, dead-code
    expiry/resampling, optional affine re-parameterization of the codebook
    statistics, and DDP-synchronized variants of each step.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks=1,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        reset_cluster_size=None,
        use_ddp=False,
        learnable_codebook=False,
        gumbel_sample=gumbel_sample,
        sample_codebook_temp=1.0,
        ema_update=True,
        affine_param=False,
        sync_affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
    ):
        super().__init__()
        # Euclidean codebooks take raw inputs (cosine variant l2-normalizes)
        self.transform_input = identity

        self.decay = decay
        self.ema_update = ema_update

        # with kmeans init the embedding starts at zero and is filled lazily
        # on the first forward pass (see init_embed_)
        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        assert not (
            use_ddp and num_codebooks > 1 and kmeans_init
        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"

        # DDP-aware sampling/reduction hooks; no-ops in single-process runs
        self.sample_fn = (
            sample_vectors_distributed
            if use_ddp and sync_kmeans
            else batched_sample_vectors
        )
        self.kmeans_all_reduce_fn = (
            distributed.all_reduce if use_ddp and sync_kmeans else noop
        )
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        # "initted" gates the lazy kmeans initialization in forward
        self.register_buffer("initted", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
        self.register_buffer("embed_avg", embed.clone())

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer("embed", embed)

        # affine related params

        self.affine_param = affine_param
        self.sync_affine_param = sync_affine_param

        if not affine_param:
            return

        self.affine_param_batch_decay = affine_param_batch_decay
        self.affine_param_codebook_decay = affine_param_codebook_decay

        # batch statistics start as None and are created on first update
        self.register_buffer("batch_mean", None)
        self.register_buffer("batch_variance", None)

        self.register_buffer("codebook_mean_needs_init", torch.Tensor([True]))
        self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim))
        self.register_buffer("codebook_variance_needs_init", torch.Tensor([True]))
        self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim))

    @torch.jit.ignore
    def init_embed_(self, data, mask=None):
        """Lazily initialize the codebook with kmeans on the first batch."""
        if self.initted:
            return

        if exists(mask):
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn=self.sample_fn,
            all_reduce_fn=self.kmeans_all_reduce_fn,
        )

        # seed embed_avg consistently so the first EMA step is well-scaled
        embed_sum = embed * rearrange(cluster_size, "... -> ... 1")

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    @torch.jit.ignore
    def update_with_decay(self, buffer_name, new_value, decay):
        """EMA-update a named buffer, creating it on first use."""
        old_value = getattr(self, buffer_name)

        needs_init = getattr(self, buffer_name + "_needs_init", False)

        if needs_init:
            self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

        # first observation: adopt the new value directly
        if not exists(old_value) or needs_init:
            self.register_buffer(buffer_name, new_value.detach())

            return

        value = old_value * decay + new_value.detach() * (1 - decay)
        self.register_buffer(buffer_name, value)

    @torch.jit.ignore
    def update_affine(self, data, embed, mask=None):
        """Track running mean/variance of both codebook and batch for the
        affine re-parameterization (optionally all-reduced across ranks)."""
        assert self.affine_param

        var_fn = partial(torch.var, unbiased=False)

        # calculate codebook mean and variance

        embed = rearrange(embed, "h ... d -> h (...) d")

        if self.training:
            self.update_with_decay(
                "codebook_mean",
                reduce(embed, "h n d -> h 1 d", "mean"),
                self.affine_param_codebook_decay,
            )
            self.update_with_decay(
                "codebook_variance",
                reduce(embed, "h n d -> h 1 d", var_fn),
                self.affine_param_codebook_decay,
            )

        # prepare batch data, which depends on whether it has masking

        data = rearrange(data, "h ... d -> h (...) d")

        if exists(mask):
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        # calculate batch mean and variance

        if not self.sync_affine_param:
            self.update_with_decay(
                "batch_mean",
                reduce(data, "h n d -> h 1 d", "mean"),
                self.affine_param_batch_decay,
            )
            self.update_with_decay(
                "batch_variance",
                reduce(data, "h n d -> h 1 d", var_fn),
                self.affine_param_batch_decay,
            )
            return

        num_vectors, device, dtype = data.shape[-2], data.device, data.dtype

        # number of vectors, for denominator

        num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype)
        distributed.all_reduce(num_vectors)

        # calculate distributed mean

        batch_sum = reduce(data, "h n d -> h 1 d", "sum")
        distributed.all_reduce(batch_sum)
        batch_mean = batch_sum / num_vectors

        self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)

        # calculate distributed variance

        variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
        distributed.all_reduce(variance_numer)
        batch_variance = variance_numer / num_vectors

        self.update_with_decay(
            "batch_variance", batch_variance, self.affine_param_batch_decay
        )

    def replace(self, batch_samples, batch_mask):
        """Overwrite masked codes in each codebook with fresh data samples."""
        for ind, (samples, mask) in enumerate(
            zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
        ):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(
                rearrange(samples, "... -> 1 ..."), mask.sum().item()
            )
            sampled = rearrange(sampled, "1 ... -> ...")

            self.embed.data[ind][mask] = sampled

            # reset EMA statistics so replaced codes don't expire immediately
            self.cluster_size.data[ind][mask] = self.reset_cluster_size
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size

    def expire_codes_(self, batch_samples):
        """Resample codes whose EMA cluster size fell below the dead threshold."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
        self.replace(batch_samples, batch_mask=expired_codes)

    @autocast(enabled=False)
    def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
        """Quantize *x*; returns ``(quantize, embed_ind, dist)``.

        x is presumably shaped (..., n, d) — a leading codebook dim is added
        when x.ndim < 4. Runs in fp32 (autocast disabled) for EMA stability.
        """
        needs_codebook_dim = x.ndim < 4
        sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, "... -> 1 ...")

        # NOTE(review): `dtype` is assigned but unused below; kept as-is.
        dtype = x.dtype
        flatten, ps = pack_one(x, "h * d")

        if exists(mask):
            # expand the (b, n) validity mask to the flattened (h, b*n) layout
            mask = repeat(
                mask,
                "b n -> c (b h n)",
                c=flatten.shape[0],
                h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
            )

        self.init_embed_(flatten, mask=mask)

        if self.affine_param:
            self.update_affine(flatten, self.embed, mask=mask)

        embed = self.embed if self.learnable_codebook else self.embed.detach()

        if self.affine_param:
            # map codebook into the batch's statistics before measuring distances
            codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
            batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
            embed = (embed - self.codebook_mean) * (
                batch_std / codebook_std
            ) + self.batch_mean

        # negative distance so that argmax picks the nearest code
        dist = -cdist(flatten, embed)

        embed_ind, embed_onehot = self.gumbel_sample(
            dist, dim=-1, temperature=sample_codebook_temp, training=self.training
        )

        embed_ind = unpack_one(embed_ind, ps, "h *")

        if self.training:
            # one-hot matmul keeps the quantization differentiable w.r.t. embed
            unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
            quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
        else:
            quantize = batched_embedding(embed_ind, embed)

        if self.training and self.ema_update and not freeze_codebook:
            if self.affine_param:
                # inverse affine map: bring batch vectors into codebook space
                flatten = (flatten - self.batch_mean) * (
                    codebook_std / batch_std
                ) + self.codebook_mean

            if exists(mask):
                # masked-out positions contribute nothing to the EMA stats
                embed_onehot[~mask] = 0.0

            cluster_size = embed_onehot.sum(dim=1)

            self.all_reduce_fn(cluster_size)
            ema_inplace(self.cluster_size.data, cluster_size, self.decay)

            embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            ema_inplace(self.embed_avg.data, embed_sum, self.decay)

            # smoothed counts avoid division by zero for rarely-used codes
            cluster_size = laplace_smoothing(
                self.cluster_size, self.codebook_size, self.eps
            ) * self.cluster_size.sum(dim=-1, keepdim=True)

            embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = map(
                lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind)
            )

        dist = unpack_one(dist, ps, "h * d")

        return quantize, embed_ind, dist
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
class CosineSimCodebook(nn.Module):
    """Codebook with cosine-similarity assignment and EMA updates.

    Mirrors :class:`EuclideanCodebook` but keeps inputs and codes on the unit
    hypersphere: inputs are l2-normalized (``transform_input``), assignment is
    a dot product, and EMA-updated codes are re-normalized each step. No
    affine re-parameterization in this variant.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks=1,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        reset_cluster_size=None,
        use_ddp=False,
        learnable_codebook=False,
        gumbel_sample=gumbel_sample,
        sample_codebook_temp=1.0,
        ema_update=True,
    ):
        super().__init__()
        # inputs are l2-normalized before quantization
        self.transform_input = l2norm

        self.ema_update = ema_update
        self.decay = decay

        # uniform init is normalized onto the sphere; kmeans init fills lazily
        if not kmeans_init:
            embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
        else:
            embed = torch.zeros(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        # DDP-aware hooks; no-ops when not using DDP
        self.sample_fn = (
            sample_vectors_distributed
            if use_ddp and sync_kmeans
            else batched_sample_vectors
        )
        self.kmeans_all_reduce_fn = (
            distributed.all_reduce if use_ddp and sync_kmeans else noop
        )
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer("initted", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
        self.register_buffer("embed_avg", embed.clone())

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer("embed", embed)

    @torch.jit.ignore
    def init_embed_(self, data, mask=None):
        """Lazily initialize the codebook with spherical kmeans on first batch."""
        if self.initted:
            return

        if exists(mask):
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            use_cosine_sim=True,
            sample_fn=self.sample_fn,
            all_reduce_fn=self.kmeans_all_reduce_fn,
        )

        embed_sum = embed * rearrange(cluster_size, "... -> ... 1")

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, batch_samples, batch_mask):
        """Overwrite masked codes with fresh (l2-normalized) data samples."""
        batch_samples = l2norm(batch_samples)

        for ind, (samples, mask) in enumerate(
            zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
        ):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(
                rearrange(samples, "... -> 1 ..."), mask.sum().item()
            )
            sampled = rearrange(sampled, "1 ... -> ...")

            self.embed.data[ind][mask] = sampled
            # reset EMA statistics so replaced codes are not expired right away
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
            self.cluster_size.data[ind][mask] = self.reset_cluster_size

    def expire_codes_(self, batch_samples):
        """Resample codes whose EMA cluster size fell below the dead threshold."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
        self.replace(batch_samples, batch_mask=expired_codes)

    @autocast(enabled=False)
    def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
        """Quantize *x* by cosine similarity; returns ``(quantize, embed_ind, dist)``.

        x is presumably already l2-normalized by the caller via
        ``transform_input`` — confirm. Runs in fp32 (autocast disabled).
        """
        needs_codebook_dim = x.ndim < 4
        sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, "... -> 1 ...")

        # NOTE(review): `dtype` is assigned but unused below; kept as-is.
        dtype = x.dtype

        flatten, ps = pack_one(x, "h * d")

        if exists(mask):
            # expand the (b, n) validity mask to the flattened (h, b*n) layout
            mask = repeat(
                mask,
                "b n -> c (b h n)",
                c=flatten.shape[0],
                h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
            )

        self.init_embed_(flatten, mask=mask)

        embed = self.embed if self.learnable_codebook else self.embed.detach()

        # cosine similarity (dot product of normalized vectors); argmax = nearest
        dist = einsum("h n d, h c d -> h n c", flatten, embed)

        embed_ind, embed_onehot = self.gumbel_sample(
            dist, dim=-1, temperature=sample_codebook_temp, training=self.training
        )
        embed_ind = unpack_one(embed_ind, ps, "h *")

        if self.training:
            # one-hot matmul keeps the quantization differentiable w.r.t. embed
            unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
            quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
        else:
            quantize = batched_embedding(embed_ind, embed)

        if self.training and self.ema_update and not freeze_codebook:
            if exists(mask):
                # masked-out positions contribute nothing to the EMA stats
                embed_onehot[~mask] = 0.0

            bins = embed_onehot.sum(dim=1)
            self.all_reduce_fn(bins)

            ema_inplace(self.cluster_size.data, bins, self.decay)

            embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            ema_inplace(self.embed_avg.data, embed_sum, self.decay)

            # smoothed counts avoid division by zero for rarely-used codes
            cluster_size = laplace_smoothing(
                self.cluster_size, self.codebook_size, self.eps
            ) * self.cluster_size.sum(dim=-1, keepdim=True)

            embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
            embed_normalized = l2norm(embed_normalized)

            # NOTE(review): already normalized above, so this second l2norm
            # is an idempotent no-op; kept byte-identical.
            self.embed.data.copy_(l2norm(embed_normalized))
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = map(
                lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind)
            )

        dist = unpack_one(dist, ps, "h * d")
        return quantize, embed_ind, dist
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
# main class
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
class VectorQuantize(nn.Module):
    """Vector-quantization layer (VQ-VAE style).

    Projects the input into codebook space, snaps each vector to a codebook
    entry via the wrapped codebook module (``EuclideanCodebook`` or
    ``CosineSimCodebook``), and returns the straight-through quantized output
    together with the selected code indices and an auxiliary training loss
    (commitment loss and/or orthogonal codebook regularization).

    Supports multi-head quantization (optionally with a separate codebook per
    head), image feature maps (``b c h w`` inputs), channel-first sequences,
    variable-length sequences via a boolean ``mask``, EMA or gradient-based
    codebook updates, and an optional in-place codebook optimizer step.
    """

    def __init__(
        self,
        dim,  # input feature dimension
        codebook_size,  # number of codes per codebook
        codebook_dim=None,  # internal codebook dimension; defaults to `dim` (projections added when they differ)
        heads=1,  # number of quantization heads
        separate_codebook_per_head=False,  # one codebook per head instead of a single shared one
        decay=0.8,  # EMA decay for codebook updates
        eps=1e-5,  # numerical-stability epsilon (laplace smoothing in the codebook)
        freeze_codebook=False,  # NOTE(review): stored nowhere and unused here; freezing is driven by forward's `freeze_codebook` kwarg — confirm against upstream
        kmeans_init=True,  # initialize codebook with k-means over the first batch
        kmeans_iters=10,
        sync_kmeans=True,
        use_cosine_sim=False,  # cosine-similarity codebook instead of euclidean
        threshold_ema_dead_code=2,  # codes whose EMA cluster size falls below this get expired/replaced
        channel_last=True,  # if False (and not image fmap), input is (b, d, n) and gets transposed
        accept_image_fmap=False,  # if True, input is (b, c, h, w)
        commitment_weight=1.0,
        commitment_use_cross_entropy_loss=False,  # whether to use cross entropy loss to codebook as commitment loss
        orthogonal_reg_weight=0.0,  # weight for orthogonal codebook regularization (0 disables)
        orthogonal_reg_active_codes_only=False,
        orthogonal_reg_max_codes=None,
        stochastic_sample_codes=False,  # gumbel-sample codes instead of argmin/argmax
        sample_codebook_temp=1.0,
        straight_through=False,
        reinmax=False,  # using reinmax for improved straight-through, assuming straight through helps at all
        sync_codebook=None,  # sync codebook across DDP ranks; auto-detected from torch.distributed when None
        sync_affine_param=False,
        ema_update=True,  # EMA codebook update (mutually exclusive with learnable_codebook)
        learnable_codebook=False,  # update codebook by gradient descent instead of EMA
        in_place_codebook_optimizer: Callable[
            ..., Optimizer
        ] = None,  # Optimizer used to update the codebook embedding if using learnable_codebook
        affine_param=False,  # affine re-parameterization of the euclidean codebook
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
        sync_update_v=0.0,  # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
        eval=False,  # NOTE(review): shadows the builtin `eval`; kept for keyword-arg compatibility
    ):
        super().__init__()
        # Inference-only construction: skip k-means init and disable EMA updates.
        if eval:
            kmeans_init = False
            ema_update = False
        self.dim = dim
        self.heads = heads
        self.separate_codebook_per_head = separate_codebook_per_head

        codebook_dim = default(codebook_dim, dim)
        codebook_input_dim = codebook_dim * heads

        # Only add in/out projections when the (per-head) codebook space
        # differs from the input dimension.
        requires_projection = codebook_input_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        )

        self.eps = eps
        self.commitment_weight = commitment_weight
        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss  # whether to use cross entropy loss to codebook as commitment loss

        self.learnable_codebook = learnable_codebook

        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        assert not (
            ema_update and learnable_codebook
        ), "learnable codebook not compatible with EMA update"

        assert 0 <= sync_update_v <= 1.0
        assert not (
            sync_update_v > 0.0 and not learnable_codebook
        ), "learnable codebook must be turned on"

        self.sync_update_v = sync_update_v

        codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook

        # Pre-bind the sampling hyperparameters; the codebook calls this to
        # pick code indices from the distance matrix.
        gumbel_sample_fn = partial(
            gumbel_sample,
            stochastic=stochastic_sample_codes,
            reinmax=reinmax,
            straight_through=straight_through,
        )

        # Auto-enable distributed codebook sync when running multi-process DDP.
        if not exists(sync_codebook):
            sync_codebook = (
                distributed.is_initialized() and distributed.get_world_size() > 1
            )

        codebook_kwargs = dict(
            dim=codebook_dim,
            num_codebooks=heads if separate_codebook_per_head else 1,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            sync_kmeans=sync_kmeans,
            decay=decay,
            eps=eps,
            threshold_ema_dead_code=threshold_ema_dead_code,
            use_ddp=sync_codebook,
            # orthogonal loss needs gradients through the embeddings, so it
            # also forces a learnable codebook internally
            learnable_codebook=has_codebook_orthogonal_loss or learnable_codebook,
            sample_codebook_temp=sample_codebook_temp,
            gumbel_sample=gumbel_sample_fn,
            ema_update=ema_update,
        )

        if affine_param:
            assert (
                not use_cosine_sim
            ), "affine param is only compatible with euclidean codebook"
            codebook_kwargs = dict(
                **codebook_kwargs,
                affine_param=True,
                sync_affine_param=sync_affine_param,
                affine_param_batch_decay=affine_param_batch_decay,
                affine_param_codebook_decay=affine_param_codebook_decay,
            )

        self._codebook = codebook_class(**codebook_kwargs)
        if eval:
            self._codebook.eval()

        self.in_place_codebook_optimizer = (
            in_place_codebook_optimizer(self._codebook.parameters())
            if exists(in_place_codebook_optimizer)
            else None
        )

        self.codebook_size = codebook_size

        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

    @property
    def codebook(self):
        """Codebook embeddings; drops the leading codebook axis when all heads share one codebook."""
        codebook = self._codebook.embed

        if self.separate_codebook_per_head:
            return codebook

        return rearrange(codebook, "1 ... -> ...")

    @codebook.setter
    def codebook(self, codes):
        """Overwrite the codebook in place, re-adding the leading axis for the shared-codebook layout."""
        if not self.separate_codebook_per_head:
            codes = rearrange(codes, "... -> 1 ...")

        self._codebook.embed.copy_(codes)

    def get_codes_from_indices(self, indices):
        """Look up codebook vectors for the given indices.

        For multi-head codebooks the per-head codes are gathered and
        concatenated along the feature dimension, returning `(... , h*d)`.
        """
        codebook = self.codebook
        is_multiheaded = codebook.ndim > 2

        if not is_multiheaded:
            codes = codebook[indices]
            return rearrange(codes, "... h d -> ... (h d)")

        indices, ps = pack_one(indices, "b * h")
        indices = rearrange(indices, "b n h -> b h n")

        # broadcast indices/codebook to a common (b, h, n, d) shape for gather
        indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1])
        codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0])

        codes = codebook.gather(2, indices)
        codes = rearrange(codes, "b h n d -> b n (h d)")
        codes = unpack_one(codes, ps, "b * d")
        return codes

    def forward(
        self,
        x,
        indices=None,  # target codes; when given, returns (quantize, ce_loss) instead
        mask=None,  # boolean (b, n) mask for variable-length sequences
        sample_codebook_temp=None,
        freeze_codebook=False,  # skip codebook updates for this call
    ):
        """Quantize `x`.

        Returns ``(quantize, embed_ind, loss)``, or ``(quantize, ce_loss)``
        when `indices` is provided (cross-entropy of distances vs. the given
        codes). `loss` is nonzero only in training mode.
        """
        orig_input = x

        # (b, d) inputs are treated as a length-1 sequence and unwrapped at the end
        only_one = x.ndim == 2

        if only_one:
            assert not exists(mask)
            x = rearrange(x, "b d -> b 1 d")

        shape, device, heads, is_multiheaded, codebook_size, return_loss = (
            x.shape,
            x.device,
            self.heads,
            self.heads > 1,
            self.codebook_size,
            exists(indices),
        )

        need_transpose = not self.channel_last and not self.accept_image_fmap
        should_inplace_optimize = exists(self.in_place_codebook_optimizer)

        # rearrange inputs

        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, "b c h w -> b (h w) c")

        if need_transpose:
            x = rearrange(x, "b d n -> b n d")

        # project input

        x = self.project_in(x)

        # handle multi-headed separate codebooks
        # separate codebooks keep a head axis; shared codebook folds heads into batch

        if is_multiheaded:
            ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
            x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)

        # l2norm for cosine sim, otherwise identity

        x = self._codebook.transform_input(x)

        # codebook forward kwargs

        codebook_forward_kwargs = dict(
            sample_codebook_temp=sample_codebook_temp,
            mask=mask,
            freeze_codebook=freeze_codebook,
        )

        # quantize

        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        # one step in-place update: take one optimizer step on the codebook
        # toward the (detached) inputs, then re-quantize with the updated codes

        if should_inplace_optimize and self.training and not freeze_codebook:
            if exists(mask):
                loss = F.mse_loss(quantize, x.detach(), reduction="none")

                loss_mask = mask
                if is_multiheaded:
                    # expand the (b, n) mask to match the folded (c, b*h, n) loss layout
                    loss_mask = repeat(
                        mask,
                        "b n -> c (b h) n",
                        c=loss.shape[0],
                        h=loss.shape[1] // mask.shape[0],
                    )

                loss = loss[loss_mask].mean()

            else:
                loss = F.mse_loss(quantize, x.detach())

            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            # quantize again

            quantize, embed_ind, distances = self._codebook(
                x, **codebook_forward_kwargs
            )

        if self.training:
            # determine code to use for commitment loss
            # detach unless gradients should flow into the codebook
            maybe_detach = (
                torch.detach
                if not self.learnable_codebook or freeze_codebook
                else identity
            )

            commit_quantize = maybe_detach(quantize)

            # straight through

            quantize = x + (quantize - x).detach()

            if self.sync_update_v > 0.0:
                # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
                quantize = quantize + self.sync_update_v * (
                    quantize - quantize.detach()
                )

        # function for calculating cross entropy loss to distance matrix
        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss

        def calculate_ce_loss(codes):
            # pick the einops pattern that moves codes to the class axis
            # expected by F.cross_entropy for each head layout
            if not is_multiheaded:
                dist_einops_eq = "1 b n l -> b l n"
            elif self.separate_codebook_per_head:
                dist_einops_eq = "c b n l -> b l n c"
            else:
                dist_einops_eq = "1 (b h) n l -> b l n h"

            ce_loss = F.cross_entropy(
                rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1
            )

            return ce_loss

        # if returning cross entropy loss on codes that were passed in

        if return_loss:
            return quantize, calculate_ce_loss(indices)

        # transform embedding indices back to (b, n, h) / (b, n)

        if is_multiheaded:
            if self.separate_codebook_per_head:
                embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
            else:
                embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)

        if self.accept_image_fmap:
            embed_ind = rearrange(
                embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
            )

        if only_one:
            embed_ind = rearrange(embed_ind, "b 1 -> b")

        # aggregate loss

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                if self.commitment_use_cross_entropy_loss:
                    if exists(mask):
                        ce_loss_mask = mask
                        if is_multiheaded:
                            ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)

                        # masked positions are excluded via ignore_index=-1
                        embed_ind.masked_fill_(~ce_loss_mask, -1)

                    commit_loss = calculate_ce_loss(embed_ind)
                else:
                    if exists(mask):
                        # with variable lengthed sequences
                        commit_loss = F.mse_loss(commit_quantize, x, reduction="none")

                        loss_mask = mask
                        if is_multiheaded:
                            loss_mask = repeat(
                                loss_mask,
                                "b n -> c (b h) n",
                                c=commit_loss.shape[0],
                                h=commit_loss.shape[1] // mask.shape[0],
                            )

                        commit_loss = commit_loss[loss_mask].mean()
                    else:
                        commit_loss = F.mse_loss(commit_quantize, x)

                loss = loss + commit_loss * self.commitment_weight

            if self.has_codebook_orthogonal_loss:
                codebook = self._codebook.embed

                # only calculate orthogonal loss for the activated codes for this batch

                if self.orthogonal_reg_active_codes_only:
                    assert not (
                        is_multiheaded and self.separate_codebook_per_head
                    ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[:, unique_code_ids]

                num_codes = codebook.shape[-2]

                # cap the number of codes entering the orthogonal loss by random subsampling
                if (
                    exists(self.orthogonal_reg_max_codes)
                    and num_codes > self.orthogonal_reg_max_codes
                ):
                    rand_ids = torch.randperm(num_codes, device=device)[
                        : self.orthogonal_reg_max_codes
                    ]
                    codebook = codebook[:, rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        # handle multi-headed quantized embeddings

        if is_multiheaded:
            if self.separate_codebook_per_head:
                quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
            else:
                quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)

        # project out

        quantize = self.project_out(quantize)

        # rearrange quantized embeddings back to the input layout

        if need_transpose:
            quantize = rearrange(quantize, "b n d -> b d n")

        if self.accept_image_fmap:
            quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width)

        if only_one:
            quantize = rearrange(quantize, "b 1 d -> b d")

        # if masking, only return quantized for where mask has True

        if exists(mask):
            quantize = torch.where(
                rearrange(mask, "... -> ... 1"), quantize, orig_input
            )

        return quantize, embed_ind, loss
|