egogym-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
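
The largest single addition is baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py (+1194 lines), whose diff follows; it appears to be a vendored copy of a lucidrains-style vector-quantization codebook implementation. For orientation only, the sketch below shows how the VectorQuantize module defined in that file is typically driven. It is not part of the package, and the import path is assumed from the file list above rather than verified against the wheel.

    import torch
    # Import path assumed from the file layout above (top-level package "baselines").
    from baselines.rum.models.bet.vqvae.vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(
        dim=256,               # input feature dimension
        codebook_size=512,     # number of discrete codes
        decay=0.8,             # EMA decay used for codebook updates
        commitment_weight=1.0, # weight on the commitment (MSE) loss term
    )

    x = torch.randn(4, 100, 256)       # (batch, sequence, dim); channel_last is the default
    quantized, indices, loss = vq(x)   # forward returns (quantize, embed_ind, loss)
    # quantized: (4, 100, 256), indices: (4, 100), loss: 1-element commitment loss tensor
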
@@ -0,0 +1,1194 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum
5
+ import torch.nn.functional as F
6
+ import torch.distributed as distributed
7
+ from torch.optim import Optimizer
8
+ from torch.cuda.amp import autocast
9
+
10
+ from einops import rearrange, repeat, reduce, pack, unpack
11
+
12
+ from typing import Callable
13
+
14
+
15
+ def exists(val):
16
+ return val is not None
17
+
18
+
19
+ def default(val, d):
20
+ return val if exists(val) else d
21
+
22
+
23
+ def noop(*args, **kwargs):
24
+ pass
25
+
26
+
27
+ def identity(t):
28
+ return t
29
+
30
+
31
+ def l2norm(t):
32
+ return F.normalize(t, p=2, dim=-1)
33
+
34
+
35
+ def cdist(x, y):
36
+ x2 = reduce(x**2, "b n d -> b n", "sum")
37
+ y2 = reduce(y**2, "b n d -> b n", "sum")
38
+ xy = einsum("b i d, b j d -> b i j", x, y) * -2
39
+ return (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy).sqrt()
40
+
41
+
42
+ def log(t, eps=1e-20):
43
+ return torch.log(t.clamp(min=eps))
44
+
45
+
46
+ def ema_inplace(old, new, decay):
47
+ is_mps = str(old.device).startswith("mps:")
48
+
49
+ if not is_mps:
50
+ old.lerp_(new, 1 - decay)
51
+ else:
52
+ old.mul_(decay).add_(new * (1 - decay))
53
+
54
+
55
+ def pack_one(t, pattern):
56
+ return pack([t], pattern)
57
+
58
+
59
+ def unpack_one(t, ps, pattern):
60
+ return unpack(t, ps, pattern)[0]
61
+
62
+
63
+ def uniform_init(*shape):
64
+ t = torch.empty(shape)
65
+ nn.init.kaiming_uniform_(t)
66
+ return t
67
+
68
+
69
+ def gumbel_noise(t):
70
+ noise = torch.zeros_like(t).uniform_(0, 1)
71
+ return -log(-log(noise))
72
+
73
+
74
+ def gumbel_sample(
75
+ logits,
76
+ temperature=1.0,
77
+ stochastic=False,
78
+ straight_through=False,
79
+ reinmax=False,
80
+ dim=-1,
81
+ training=True,
82
+ ):
83
+ dtype, size = logits.dtype, logits.shape[dim]
84
+
85
+ if training and stochastic and temperature > 0:
86
+ sampling_logits = (logits / temperature) + gumbel_noise(logits)
87
+ else:
88
+ sampling_logits = logits
89
+
90
+ ind = sampling_logits.argmax(dim=dim)
91
+ one_hot = F.one_hot(ind, size).type(dtype)
92
+
93
+ assert not (
94
+ reinmax and not straight_through
95
+ ), "reinmax can only be turned on if using straight through gumbel softmax"
96
+
97
+ if not straight_through or temperature <= 0.0 or not training:
98
+ return ind, one_hot
99
+
100
+ # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
101
+ # algorithm 2
102
+
103
+ if reinmax:
104
+ π0 = logits.softmax(dim=dim)
105
+ π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2
106
+ π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1)
107
+ π2 = 2 * π1 - 0.5 * π0
108
+ one_hot = π2 - π2.detach() + one_hot
109
+ else:
110
+ π1 = (logits / temperature).softmax(dim=dim)
111
+ one_hot = one_hot + π1 - π1.detach()
112
+
113
+ return ind, one_hot
114
+
115
+
116
+ def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1):
117
+ denom = x.sum(dim=dim, keepdim=True)
118
+ return (x + eps) / (denom + n_categories * eps)
119
+
120
+
121
+ def sample_vectors(samples, num):
122
+ num_samples, device = samples.shape[0], samples.device
123
+ if num_samples >= num:
124
+ indices = torch.randperm(num_samples, device=device)[:num]
125
+ else:
126
+ indices = torch.randint(0, num_samples, (num,), device=device)
127
+
128
+ return samples[indices]
129
+
130
+
131
+ def batched_sample_vectors(samples, num):
132
+ return torch.stack(
133
+ [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
134
+ )
135
+
136
+
137
+ def pad_shape(shape, size, dim=0):
138
+ return [size if i == dim else s for i, s in enumerate(shape)]
139
+
140
+
141
+ def sample_multinomial(total_count, probs):
142
+ device = probs.device
143
+ probs = probs.cpu()
144
+
145
+ total_count = probs.new_full((), total_count)
146
+ remainder = probs.new_ones(())
147
+ sample = torch.empty_like(probs, dtype=torch.long)
148
+
149
+ for i, p in enumerate(probs):
150
+ s = torch.binomial(total_count, p / remainder)
151
+ sample[i] = s
152
+ total_count -= s
153
+ remainder -= p
154
+
155
+ return sample.to(device)
156
+
157
+
158
+ def all_gather_sizes(x, dim):
159
+ size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device)
160
+ all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
161
+ distributed.all_gather(all_sizes, size)
162
+ return torch.stack(all_sizes)
163
+
164
+
165
+ def all_gather_variably_sized(x, sizes, dim=0):
166
+ rank = distributed.get_rank()
167
+ all_x = []
168
+
169
+ for i, size in enumerate(sizes):
170
+ t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
171
+ distributed.broadcast(t, src=i, async_op=True)
172
+ all_x.append(t)
173
+
174
+ distributed.barrier()
175
+ return all_x
176
+
177
+
178
+ def sample_vectors_distributed(local_samples, num):
179
+ local_samples = rearrange(local_samples, "1 ... -> ...")
180
+
181
+ rank = distributed.get_rank()
182
+ all_num_samples = all_gather_sizes(local_samples, dim=0)
183
+
184
+ if rank == 0:
185
+ samples_per_rank = sample_multinomial(
186
+ num, all_num_samples / all_num_samples.sum()
187
+ )
188
+ else:
189
+ samples_per_rank = torch.empty_like(all_num_samples)
190
+
191
+ distributed.broadcast(samples_per_rank, src=0)
192
+ samples_per_rank = samples_per_rank.tolist()
193
+
194
+ local_samples = sample_vectors(local_samples, samples_per_rank[rank])
195
+ all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0)
196
+ out = torch.cat(all_samples, dim=0)
197
+
198
+ return rearrange(out, "... -> 1 ...")
199
+
200
+
201
+ def batched_bincount(x, *, minlength):
202
+ batch, dtype, device = x.shape[0], x.dtype, x.device
203
+ target = torch.zeros(batch, minlength, dtype=dtype, device=device)
204
+ values = torch.ones_like(x)
205
+ target.scatter_add_(-1, x, values)
206
+ return target
207
+
208
+
209
+ def kmeans(
210
+ samples,
211
+ num_clusters,
212
+ num_iters=10,
213
+ use_cosine_sim=False,
214
+ sample_fn=batched_sample_vectors,
215
+ all_reduce_fn=noop,
216
+ ):
217
+ num_codebooks, dim, dtype, device = (
218
+ samples.shape[0],
219
+ samples.shape[-1],
220
+ samples.dtype,
221
+ samples.device,
222
+ )
223
+
224
+ means = sample_fn(samples, num_clusters)
225
+
226
+ for _ in range(num_iters):
227
+ if use_cosine_sim:
228
+ dists = samples @ rearrange(means, "h n d -> h d n")
229
+ else:
230
+ dists = -torch.cdist(samples, means, p=2)
231
+
232
+ buckets = torch.argmax(dists, dim=-1)
233
+ bins = batched_bincount(buckets, minlength=num_clusters)
234
+ all_reduce_fn(bins)
235
+
236
+ zero_mask = bins == 0
237
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
238
+
239
+ new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)
240
+
241
+ new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
242
+ new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
243
+ all_reduce_fn(new_means)
244
+
245
+ if use_cosine_sim:
246
+ new_means = l2norm(new_means)
247
+
248
+ means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)
249
+
250
+ return means, bins
251
+
252
+
253
+ def batched_embedding(indices, embeds):
254
+ batch, dim = indices.shape[1], embeds.shape[-1]
255
+ indices = repeat(indices, "h b n -> h b n d", d=dim)
256
+ embeds = repeat(embeds, "h c d -> h b c d", b=batch)
257
+ return embeds.gather(2, indices)
258
+
259
+
260
+ # regularization losses
261
+
262
+
263
+ def orthogonal_loss_fn(t):
264
+ # eq (2) from https://arxiv.org/abs/2112.00384
265
+ h, n = t.shape[:2]
266
+ normed_codes = l2norm(t)
267
+ cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
268
+ return (cosine_sim**2).sum() / (h * n**2) - (1 / n)
269
+
270
+
271
+ # distance types
272
+
273
+
274
+ class EuclideanCodebook(nn.Module):
275
+ def __init__(
276
+ self,
277
+ dim,
278
+ codebook_size,
279
+ num_codebooks=1,
280
+ kmeans_init=False,
281
+ kmeans_iters=10,
282
+ sync_kmeans=True,
283
+ decay=0.8,
284
+ eps=1e-5,
285
+ threshold_ema_dead_code=2,
286
+ reset_cluster_size=None,
287
+ use_ddp=False,
288
+ learnable_codebook=False,
289
+ gumbel_sample=gumbel_sample,
290
+ sample_codebook_temp=1.0,
291
+ ema_update=True,
292
+ affine_param=False,
293
+ sync_affine_param=False,
294
+ affine_param_batch_decay=0.99,
295
+ affine_param_codebook_decay=0.9,
296
+ ):
297
+ super().__init__()
298
+ self.transform_input = identity
299
+
300
+ self.decay = decay
301
+ self.ema_update = ema_update
302
+
303
+ init_fn = uniform_init if not kmeans_init else torch.zeros
304
+ embed = init_fn(num_codebooks, codebook_size, dim)
305
+
306
+ self.codebook_size = codebook_size
307
+ self.num_codebooks = num_codebooks
308
+
309
+ self.kmeans_iters = kmeans_iters
310
+ self.eps = eps
311
+ self.threshold_ema_dead_code = threshold_ema_dead_code
312
+ self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
313
+
314
+ assert callable(gumbel_sample)
315
+ self.gumbel_sample = gumbel_sample
316
+ self.sample_codebook_temp = sample_codebook_temp
317
+
318
+ assert not (
319
+ use_ddp and num_codebooks > 1 and kmeans_init
320
+ ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
321
+
322
+ self.sample_fn = (
323
+ sample_vectors_distributed
324
+ if use_ddp and sync_kmeans
325
+ else batched_sample_vectors
326
+ )
327
+ self.kmeans_all_reduce_fn = (
328
+ distributed.all_reduce if use_ddp and sync_kmeans else noop
329
+ )
330
+ self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
331
+
332
+ self.register_buffer("initted", torch.Tensor([not kmeans_init]))
333
+ self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
334
+ self.register_buffer("embed_avg", embed.clone())
335
+
336
+ self.learnable_codebook = learnable_codebook
337
+ if learnable_codebook:
338
+ self.embed = nn.Parameter(embed)
339
+ else:
340
+ self.register_buffer("embed", embed)
341
+
342
+ # affine related params
343
+
344
+ self.affine_param = affine_param
345
+ self.sync_affine_param = sync_affine_param
346
+
347
+ if not affine_param:
348
+ return
349
+
350
+ self.affine_param_batch_decay = affine_param_batch_decay
351
+ self.affine_param_codebook_decay = affine_param_codebook_decay
352
+
353
+ self.register_buffer("batch_mean", None)
354
+ self.register_buffer("batch_variance", None)
355
+
356
+ self.register_buffer("codebook_mean_needs_init", torch.Tensor([True]))
357
+ self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim))
358
+ self.register_buffer("codebook_variance_needs_init", torch.Tensor([True]))
359
+ self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim))
360
+
361
+ @torch.jit.ignore
362
+ def init_embed_(self, data, mask=None):
363
+ if self.initted:
364
+ return
365
+
366
+ if exists(mask):
367
+ c = data.shape[0]
368
+ data = rearrange(data[mask], "(c n) d -> c n d", c=c)
369
+
370
+ embed, cluster_size = kmeans(
371
+ data,
372
+ self.codebook_size,
373
+ self.kmeans_iters,
374
+ sample_fn=self.sample_fn,
375
+ all_reduce_fn=self.kmeans_all_reduce_fn,
376
+ )
377
+
378
+ embed_sum = embed * rearrange(cluster_size, "... -> ... 1")
379
+
380
+ self.embed.data.copy_(embed)
381
+ self.embed_avg.data.copy_(embed_sum)
382
+ self.cluster_size.data.copy_(cluster_size)
383
+ self.initted.data.copy_(torch.Tensor([True]))
384
+
385
+ @torch.jit.ignore
386
+ def update_with_decay(self, buffer_name, new_value, decay):
387
+ old_value = getattr(self, buffer_name)
388
+
389
+ needs_init = getattr(self, buffer_name + "_needs_init", False)
390
+
391
+ if needs_init:
392
+ self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))
393
+
394
+ if not exists(old_value) or needs_init:
395
+ self.register_buffer(buffer_name, new_value.detach())
396
+
397
+ return
398
+
399
+ value = old_value * decay + new_value.detach() * (1 - decay)
400
+ self.register_buffer(buffer_name, value)
401
+
402
+ @torch.jit.ignore
403
+ def update_affine(self, data, embed, mask=None):
404
+ assert self.affine_param
405
+
406
+ var_fn = partial(torch.var, unbiased=False)
407
+
408
+ # calculate codebook mean and variance
409
+
410
+ embed = rearrange(embed, "h ... d -> h (...) d")
411
+
412
+ if self.training:
413
+ self.update_with_decay(
414
+ "codebook_mean",
415
+ reduce(embed, "h n d -> h 1 d", "mean"),
416
+ self.affine_param_codebook_decay,
417
+ )
418
+ self.update_with_decay(
419
+ "codebook_variance",
420
+ reduce(embed, "h n d -> h 1 d", var_fn),
421
+ self.affine_param_codebook_decay,
422
+ )
423
+
424
+ # prepare batch data, which depends on whether it has masking
425
+
426
+ data = rearrange(data, "h ... d -> h (...) d")
427
+
428
+ if exists(mask):
429
+ c = data.shape[0]
430
+ data = rearrange(data[mask], "(c n) d -> c n d", c=c)
431
+
432
+ # calculate batch mean and variance
433
+
434
+ if not self.sync_affine_param:
435
+ self.update_with_decay(
436
+ "batch_mean",
437
+ reduce(data, "h n d -> h 1 d", "mean"),
438
+ self.affine_param_batch_decay,
439
+ )
440
+ self.update_with_decay(
441
+ "batch_variance",
442
+ reduce(data, "h n d -> h 1 d", var_fn),
443
+ self.affine_param_batch_decay,
444
+ )
445
+ return
446
+
447
+ num_vectors, device, dtype = data.shape[-2], data.device, data.dtype
448
+
449
+ # number of vectors, for denominator
450
+
451
+ num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype)
452
+ distributed.all_reduce(num_vectors)
453
+
454
+ # calculate distributed mean
455
+
456
+ batch_sum = reduce(data, "h n d -> h 1 d", "sum")
457
+ distributed.all_reduce(batch_sum)
458
+ batch_mean = batch_sum / num_vectors
459
+
460
+ self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)
461
+
462
+ # calculate distributed variance
463
+
464
+ variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
465
+ distributed.all_reduce(variance_numer)
466
+ batch_variance = variance_numer / num_vectors
467
+
468
+ self.update_with_decay(
469
+ "batch_variance", batch_variance, self.affine_param_batch_decay
470
+ )
471
+
472
+ def replace(self, batch_samples, batch_mask):
473
+ for ind, (samples, mask) in enumerate(
474
+ zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
475
+ ):
476
+ if not torch.any(mask):
477
+ continue
478
+
479
+ sampled = self.sample_fn(
480
+ rearrange(samples, "... -> 1 ..."), mask.sum().item()
481
+ )
482
+ sampled = rearrange(sampled, "1 ... -> ...")
483
+
484
+ self.embed.data[ind][mask] = sampled
485
+
486
+ self.cluster_size.data[ind][mask] = self.reset_cluster_size
487
+ self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
488
+
489
+ def expire_codes_(self, batch_samples):
490
+ if self.threshold_ema_dead_code == 0:
491
+ return
492
+
493
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
494
+
495
+ if not torch.any(expired_codes):
496
+ return
497
+
498
+ batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
499
+ self.replace(batch_samples, batch_mask=expired_codes)
500
+
501
+ @autocast(enabled=False)
502
+ def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
503
+ needs_codebook_dim = x.ndim < 4
504
+ sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
505
+
506
+ x = x.float()
507
+
508
+ if needs_codebook_dim:
509
+ x = rearrange(x, "... -> 1 ...")
510
+
511
+ dtype = x.dtype
512
+ flatten, ps = pack_one(x, "h * d")
513
+
514
+ if exists(mask):
515
+ mask = repeat(
516
+ mask,
517
+ "b n -> c (b h n)",
518
+ c=flatten.shape[0],
519
+ h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
520
+ )
521
+
522
+ self.init_embed_(flatten, mask=mask)
523
+
524
+ if self.affine_param:
525
+ self.update_affine(flatten, self.embed, mask=mask)
526
+
527
+ embed = self.embed if self.learnable_codebook else self.embed.detach()
528
+
529
+ if self.affine_param:
530
+ codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
531
+ batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
532
+ embed = (embed - self.codebook_mean) * (
533
+ batch_std / codebook_std
534
+ ) + self.batch_mean
535
+
536
+ dist = -cdist(flatten, embed)
537
+
538
+ embed_ind, embed_onehot = self.gumbel_sample(
539
+ dist, dim=-1, temperature=sample_codebook_temp, training=self.training
540
+ )
541
+
542
+ embed_ind = unpack_one(embed_ind, ps, "h *")
543
+
544
+ if self.training:
545
+ unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
546
+ quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
547
+ else:
548
+ quantize = batched_embedding(embed_ind, embed)
549
+
550
+ if self.training and self.ema_update and not freeze_codebook:
551
+ if self.affine_param:
552
+ flatten = (flatten - self.batch_mean) * (
553
+ codebook_std / batch_std
554
+ ) + self.codebook_mean
555
+
556
+ if exists(mask):
557
+ embed_onehot[~mask] = 0.0
558
+
559
+ cluster_size = embed_onehot.sum(dim=1)
560
+
561
+ self.all_reduce_fn(cluster_size)
562
+ ema_inplace(self.cluster_size.data, cluster_size, self.decay)
563
+
564
+ embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
565
+ self.all_reduce_fn(embed_sum.contiguous())
566
+ ema_inplace(self.embed_avg.data, embed_sum, self.decay)
567
+
568
+ cluster_size = laplace_smoothing(
569
+ self.cluster_size, self.codebook_size, self.eps
570
+ ) * self.cluster_size.sum(dim=-1, keepdim=True)
571
+
572
+ embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
573
+ self.embed.data.copy_(embed_normalized)
574
+ self.expire_codes_(x)
575
+
576
+ if needs_codebook_dim:
577
+ quantize, embed_ind = map(
578
+ lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind)
579
+ )
580
+
581
+ dist = unpack_one(dist, ps, "h * d")
582
+
583
+ return quantize, embed_ind, dist
584
+
585
+
586
+ class CosineSimCodebook(nn.Module):
587
+ def __init__(
588
+ self,
589
+ dim,
590
+ codebook_size,
591
+ num_codebooks=1,
592
+ kmeans_init=False,
593
+ kmeans_iters=10,
594
+ sync_kmeans=True,
595
+ decay=0.8,
596
+ eps=1e-5,
597
+ threshold_ema_dead_code=2,
598
+ reset_cluster_size=None,
599
+ use_ddp=False,
600
+ learnable_codebook=False,
601
+ gumbel_sample=gumbel_sample,
602
+ sample_codebook_temp=1.0,
603
+ ema_update=True,
604
+ ):
605
+ super().__init__()
606
+ self.transform_input = l2norm
607
+
608
+ self.ema_update = ema_update
609
+ self.decay = decay
610
+
611
+ if not kmeans_init:
612
+ embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
613
+ else:
614
+ embed = torch.zeros(num_codebooks, codebook_size, dim)
615
+
616
+ self.codebook_size = codebook_size
617
+ self.num_codebooks = num_codebooks
618
+
619
+ self.kmeans_iters = kmeans_iters
620
+ self.eps = eps
621
+ self.threshold_ema_dead_code = threshold_ema_dead_code
622
+ self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
623
+
624
+ assert callable(gumbel_sample)
625
+ self.gumbel_sample = gumbel_sample
626
+ self.sample_codebook_temp = sample_codebook_temp
627
+
628
+ self.sample_fn = (
629
+ sample_vectors_distributed
630
+ if use_ddp and sync_kmeans
631
+ else batched_sample_vectors
632
+ )
633
+ self.kmeans_all_reduce_fn = (
634
+ distributed.all_reduce if use_ddp and sync_kmeans else noop
635
+ )
636
+ self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
637
+
638
+ self.register_buffer("initted", torch.Tensor([not kmeans_init]))
639
+ self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
640
+ self.register_buffer("embed_avg", embed.clone())
641
+
642
+ self.learnable_codebook = learnable_codebook
643
+ if learnable_codebook:
644
+ self.embed = nn.Parameter(embed)
645
+ else:
646
+ self.register_buffer("embed", embed)
647
+
648
+ @torch.jit.ignore
649
+ def init_embed_(self, data, mask=None):
650
+ if self.initted:
651
+ return
652
+
653
+ if exists(mask):
654
+ c = data.shape[0]
655
+ data = rearrange(data[mask], "(c n) d -> c n d", c=c)
656
+
657
+ embed, cluster_size = kmeans(
658
+ data,
659
+ self.codebook_size,
660
+ self.kmeans_iters,
661
+ use_cosine_sim=True,
662
+ sample_fn=self.sample_fn,
663
+ all_reduce_fn=self.kmeans_all_reduce_fn,
664
+ )
665
+
666
+ embed_sum = embed * rearrange(cluster_size, "... -> ... 1")
667
+
668
+ self.embed.data.copy_(embed)
669
+ self.embed_avg.data.copy_(embed_sum)
670
+ self.cluster_size.data.copy_(cluster_size)
671
+ self.initted.data.copy_(torch.Tensor([True]))
672
+
673
+ def replace(self, batch_samples, batch_mask):
674
+ batch_samples = l2norm(batch_samples)
675
+
676
+ for ind, (samples, mask) in enumerate(
677
+ zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
678
+ ):
679
+ if not torch.any(mask):
680
+ continue
681
+
682
+ sampled = self.sample_fn(
683
+ rearrange(samples, "... -> 1 ..."), mask.sum().item()
684
+ )
685
+ sampled = rearrange(sampled, "1 ... -> ...")
686
+
687
+ self.embed.data[ind][mask] = sampled
688
+ self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
689
+ self.cluster_size.data[ind][mask] = self.reset_cluster_size
690
+
691
+ def expire_codes_(self, batch_samples):
692
+ if self.threshold_ema_dead_code == 0:
693
+ return
694
+
695
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
696
+
697
+ if not torch.any(expired_codes):
698
+ return
699
+
700
+ batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
701
+ self.replace(batch_samples, batch_mask=expired_codes)
702
+
703
+ @autocast(enabled=False)
704
+ def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
705
+ needs_codebook_dim = x.ndim < 4
706
+ sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
707
+
708
+ x = x.float()
709
+
710
+ if needs_codebook_dim:
711
+ x = rearrange(x, "... -> 1 ...")
712
+
713
+ dtype = x.dtype
714
+
715
+ flatten, ps = pack_one(x, "h * d")
716
+
717
+ if exists(mask):
718
+ mask = repeat(
719
+ mask,
720
+ "b n -> c (b h n)",
721
+ c=flatten.shape[0],
722
+ h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
723
+ )
724
+
725
+ self.init_embed_(flatten, mask=mask)
726
+
727
+ embed = self.embed if self.learnable_codebook else self.embed.detach()
728
+
729
+ dist = einsum("h n d, h c d -> h n c", flatten, embed)
730
+
731
+ embed_ind, embed_onehot = self.gumbel_sample(
732
+ dist, dim=-1, temperature=sample_codebook_temp, training=self.training
733
+ )
734
+ embed_ind = unpack_one(embed_ind, ps, "h *")
735
+
736
+ if self.training:
737
+ unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
738
+ quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
739
+ else:
740
+ quantize = batched_embedding(embed_ind, embed)
741
+
742
+ if self.training and self.ema_update and not freeze_codebook:
743
+ if exists(mask):
744
+ embed_onehot[~mask] = 0.0
745
+
746
+ bins = embed_onehot.sum(dim=1)
747
+ self.all_reduce_fn(bins)
748
+
749
+ ema_inplace(self.cluster_size.data, bins, self.decay)
750
+
751
+ embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
752
+ self.all_reduce_fn(embed_sum.contiguous())
753
+ ema_inplace(self.embed_avg.data, embed_sum, self.decay)
754
+
755
+ cluster_size = laplace_smoothing(
756
+ self.cluster_size, self.codebook_size, self.eps
757
+ ) * self.cluster_size.sum(dim=-1, keepdim=True)
758
+
759
+ embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
760
+ embed_normalized = l2norm(embed_normalized)
761
+
762
+ self.embed.data.copy_(l2norm(embed_normalized))
763
+ self.expire_codes_(x)
764
+
765
+ if needs_codebook_dim:
766
+ quantize, embed_ind = map(
767
+ lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind)
768
+ )
769
+
770
+ dist = unpack_one(dist, ps, "h * d")
771
+ return quantize, embed_ind, dist
772
+
773
+
774
+ # main class
775
+
776
+
777
+ class VectorQuantize(nn.Module):
778
+ def __init__(
779
+ self,
780
+ dim,
781
+ codebook_size,
782
+ codebook_dim=None,
783
+ heads=1,
784
+ separate_codebook_per_head=False,
785
+ decay=0.8,
786
+ eps=1e-5,
787
+ freeze_codebook=False,
788
+ kmeans_init=True,
789
+ kmeans_iters=10,
790
+ sync_kmeans=True,
791
+ use_cosine_sim=False,
792
+ threshold_ema_dead_code=2,
793
+ channel_last=True,
794
+ accept_image_fmap=False,
795
+ commitment_weight=1.0,
796
+ commitment_use_cross_entropy_loss=False,
797
+ orthogonal_reg_weight=0.0,
798
+ orthogonal_reg_active_codes_only=False,
799
+ orthogonal_reg_max_codes=None,
800
+ stochastic_sample_codes=False,
801
+ sample_codebook_temp=1.0,
802
+ straight_through=False,
803
+ reinmax=False, # using reinmax for improved straight-through, assuming straight through helps at all
804
+ sync_codebook=None,
805
+ sync_affine_param=False,
806
+ ema_update=True,
807
+ learnable_codebook=False,
808
+ in_place_codebook_optimizer: Callable[
809
+ ..., Optimizer
810
+ ] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
811
+ affine_param=False,
812
+ affine_param_batch_decay=0.99,
813
+ affine_param_codebook_decay=0.9,
814
+ sync_update_v=0.0, # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
815
+ eval=False,
816
+ ):
817
+ super().__init__()
818
+ if eval:
819
+ kmeans_init = False
820
+ ema_update = False
821
+ self.dim = dim
822
+ self.heads = heads
823
+ self.separate_codebook_per_head = separate_codebook_per_head
824
+
825
+ codebook_dim = default(codebook_dim, dim)
826
+ codebook_input_dim = codebook_dim * heads
827
+
828
+ requires_projection = codebook_input_dim != dim
829
+ self.project_in = (
830
+ nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
831
+ )
832
+ self.project_out = (
833
+ nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
834
+ )
835
+
836
+ self.eps = eps
837
+ self.commitment_weight = commitment_weight
838
+ self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
839
+
840
+ self.learnable_codebook = learnable_codebook
841
+
842
+ has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
843
+ self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
844
+ self.orthogonal_reg_weight = orthogonal_reg_weight
845
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
846
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
847
+
848
+ assert not (
849
+ ema_update and learnable_codebook
850
+ ), "learnable codebook not compatible with EMA update"
851
+
852
+ assert 0 <= sync_update_v <= 1.0
853
+ assert not (
854
+ sync_update_v > 0.0 and not learnable_codebook
855
+ ), "learnable codebook must be turned on"
856
+
857
+ self.sync_update_v = sync_update_v
858
+
859
+ codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
860
+
861
+ gumbel_sample_fn = partial(
862
+ gumbel_sample,
863
+ stochastic=stochastic_sample_codes,
864
+ reinmax=reinmax,
865
+ straight_through=straight_through,
866
+ )
867
+
868
+ if not exists(sync_codebook):
869
+ sync_codebook = (
870
+ distributed.is_initialized() and distributed.get_world_size() > 1
871
+ )
872
+
873
+ codebook_kwargs = dict(
874
+ dim=codebook_dim,
875
+ num_codebooks=heads if separate_codebook_per_head else 1,
876
+ codebook_size=codebook_size,
877
+ kmeans_init=kmeans_init,
878
+ kmeans_iters=kmeans_iters,
879
+ sync_kmeans=sync_kmeans,
880
+ decay=decay,
881
+ eps=eps,
882
+ threshold_ema_dead_code=threshold_ema_dead_code,
883
+ use_ddp=sync_codebook,
884
+ learnable_codebook=has_codebook_orthogonal_loss or learnable_codebook,
885
+ sample_codebook_temp=sample_codebook_temp,
886
+ gumbel_sample=gumbel_sample_fn,
887
+ ema_update=ema_update,
888
+ )
889
+
890
+ if affine_param:
891
+ assert (
892
+ not use_cosine_sim
893
+ ), "affine param is only compatible with euclidean codebook"
894
+ codebook_kwargs = dict(
895
+ **codebook_kwargs,
896
+ affine_param=True,
897
+ sync_affine_param=sync_affine_param,
898
+ affine_param_batch_decay=affine_param_batch_decay,
899
+ affine_param_codebook_decay=affine_param_codebook_decay,
900
+ )
901
+
902
+ self._codebook = codebook_class(**codebook_kwargs)
903
+ if eval:
904
+ self._codebook.eval()
905
+
906
+ self.in_place_codebook_optimizer = (
907
+ in_place_codebook_optimizer(self._codebook.parameters())
908
+ if exists(in_place_codebook_optimizer)
909
+ else None
910
+ )
911
+
912
+ self.codebook_size = codebook_size
913
+
914
+ self.accept_image_fmap = accept_image_fmap
915
+ self.channel_last = channel_last
916
+
917
+ @property
918
+ def codebook(self):
919
+ codebook = self._codebook.embed
920
+
921
+ if self.separate_codebook_per_head:
922
+ return codebook
923
+
924
+ return rearrange(codebook, "1 ... -> ...")
925
+
926
+ @codebook.setter
927
+ def codebook(self, codes):
928
+ if not self.separate_codebook_per_head:
929
+ codes = rearrange(codes, "... -> 1 ...")
930
+
931
+ self._codebook.embed.copy_(codes)
932
+
933
+ def get_codes_from_indices(self, indices):
934
+ codebook = self.codebook
935
+ is_multiheaded = codebook.ndim > 2
936
+
937
+ if not is_multiheaded:
938
+ codes = codebook[indices]
939
+ return rearrange(codes, "... h d -> ... (h d)")
940
+
941
+ indices, ps = pack_one(indices, "b * h")
942
+ indices = rearrange(indices, "b n h -> b h n")
943
+
944
+ indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1])
945
+ codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0])
946
+
947
+ codes = codebook.gather(2, indices)
948
+ codes = rearrange(codes, "b h n d -> b n (h d)")
949
+ codes = unpack_one(codes, ps, "b * d")
950
+ return codes
951
+
952
+ def forward(
953
+ self,
954
+ x,
955
+ indices=None,
956
+ mask=None,
957
+ sample_codebook_temp=None,
958
+ freeze_codebook=False,
959
+ ):
960
+ orig_input = x
961
+
962
+ only_one = x.ndim == 2
963
+
964
+ if only_one:
965
+ assert not exists(mask)
966
+ x = rearrange(x, "b d -> b 1 d")
967
+
968
+ shape, device, heads, is_multiheaded, codebook_size, return_loss = (
969
+ x.shape,
970
+ x.device,
971
+ self.heads,
972
+ self.heads > 1,
973
+ self.codebook_size,
974
+ exists(indices),
975
+ )
976
+
977
+ need_transpose = not self.channel_last and not self.accept_image_fmap
978
+ should_inplace_optimize = exists(self.in_place_codebook_optimizer)
979
+
980
+ # rearrange inputs
981
+
982
+ if self.accept_image_fmap:
983
+ height, width = x.shape[-2:]
984
+ x = rearrange(x, "b c h w -> b (h w) c")
985
+
986
+ if need_transpose:
987
+ x = rearrange(x, "b d n -> b n d")
988
+
989
+ # project input
990
+
991
+ x = self.project_in(x)
992
+
993
+ # handle multi-headed separate codebooks
994
+
995
+ if is_multiheaded:
996
+ ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
997
+ x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)
998
+
999
+ # l2norm for cosine sim, otherwise identity
1000
+
1001
+ x = self._codebook.transform_input(x)
1002
+
1003
+ # codebook forward kwargs
1004
+
1005
+ codebook_forward_kwargs = dict(
1006
+ sample_codebook_temp=sample_codebook_temp,
1007
+ mask=mask,
1008
+ freeze_codebook=freeze_codebook,
1009
+ )
1010
+
1011
+ # quantize
1012
+
1013
+ quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
1014
+
1015
+ # one step in-place update
1016
+
1017
+ if should_inplace_optimize and self.training and not freeze_codebook:
1018
+ if exists(mask):
1019
+ loss = F.mse_loss(quantize, x.detach(), reduction="none")
1020
+
1021
+ loss_mask = mask
1022
+ if is_multiheaded:
1023
+ loss_mask = repeat(
1024
+ mask,
1025
+ "b n -> c (b h) n",
1026
+ c=loss.shape[0],
1027
+ h=loss.shape[1] // mask.shape[0],
1028
+ )
1029
+
1030
+ loss = loss[loss_mask].mean()
1031
+
1032
+ else:
1033
+ loss = F.mse_loss(quantize, x.detach())
1034
+
1035
+ loss.backward()
1036
+ self.in_place_codebook_optimizer.step()
1037
+ self.in_place_codebook_optimizer.zero_grad()
1038
+
1039
+ # quantize again
1040
+
1041
+ quantize, embed_ind, distances = self._codebook(
1042
+ x, **codebook_forward_kwargs
1043
+ )
1044
+
1045
+ if self.training:
1046
+ # determine code to use for commitment loss
1047
+ maybe_detach = (
1048
+ torch.detach
1049
+ if not self.learnable_codebook or freeze_codebook
1050
+ else identity
1051
+ )
1052
+
1053
+ commit_quantize = maybe_detach(quantize)
1054
+
1055
+ # straight through
1056
+
1057
+ quantize = x + (quantize - x).detach()
1058
+
1059
+ if self.sync_update_v > 0.0:
1060
+ # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
1061
+ quantize = quantize + self.sync_update_v * (
1062
+ quantize - quantize.detach()
1063
+ )
1064
+
1065
+ # function for calculating cross entropy loss to distance matrix
1066
+ # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
1067
+
1068
+ def calculate_ce_loss(codes):
1069
+ if not is_multiheaded:
1070
+ dist_einops_eq = "1 b n l -> b l n"
1071
+ elif self.separate_codebook_per_head:
1072
+ dist_einops_eq = "c b n l -> b l n c"
1073
+ else:
1074
+ dist_einops_eq = "1 (b h) n l -> b l n h"
1075
+
1076
+ ce_loss = F.cross_entropy(
1077
+ rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1
1078
+ )
1079
+
1080
+ return ce_loss
1081
+
1082
+ # if returning cross entropy loss on codes that were passed in
1083
+
1084
+ if return_loss:
1085
+ return quantize, calculate_ce_loss(indices)
1086
+
1087
+ # transform embedding indices
1088
+
1089
+ if is_multiheaded:
1090
+ if self.separate_codebook_per_head:
1091
+ embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
1092
+ else:
1093
+ embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
1094
+
1095
+ if self.accept_image_fmap:
1096
+ embed_ind = rearrange(
1097
+ embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
1098
+ )
1099
+
1100
+ if only_one:
1101
+ embed_ind = rearrange(embed_ind, "b 1 -> b")
1102
+
1103
+ # aggregate loss
1104
+
1105
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
1106
+
1107
+ if self.training:
1108
+ if self.commitment_weight > 0:
1109
+ if self.commitment_use_cross_entropy_loss:
1110
+ if exists(mask):
1111
+ ce_loss_mask = mask
1112
+ if is_multiheaded:
1113
+ ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)
1114
+
1115
+ embed_ind.masked_fill_(~ce_loss_mask, -1)
1116
+
1117
+ commit_loss = calculate_ce_loss(embed_ind)
1118
+ else:
1119
+ if exists(mask):
1120
+ # with variable lengthed sequences
1121
+ commit_loss = F.mse_loss(commit_quantize, x, reduction="none")
1122
+
1123
+ loss_mask = mask
1124
+ if is_multiheaded:
1125
+ loss_mask = repeat(
1126
+ loss_mask,
1127
+ "b n -> c (b h) n",
1128
+ c=commit_loss.shape[0],
1129
+ h=commit_loss.shape[1] // mask.shape[0],
1130
+ )
1131
+
1132
+ commit_loss = commit_loss[loss_mask].mean()
1133
+ else:
1134
+ commit_loss = F.mse_loss(commit_quantize, x)
1135
+
1136
+ loss = loss + commit_loss * self.commitment_weight
1137
+
1138
+ if self.has_codebook_orthogonal_loss:
1139
+ codebook = self._codebook.embed
1140
+
1141
+ # only calculate orthogonal loss for the activated codes for this batch
1142
+
1143
+ if self.orthogonal_reg_active_codes_only:
1144
+ assert not (
1145
+ is_multiheaded and self.separate_codebook_per_head
1146
+ ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
1147
+ unique_code_ids = torch.unique(embed_ind)
1148
+ codebook = codebook[:, unique_code_ids]
1149
+
1150
+ num_codes = codebook.shape[-2]
1151
+
1152
+ if (
1153
+ exists(self.orthogonal_reg_max_codes)
1154
+ and num_codes > self.orthogonal_reg_max_codes
1155
+ ):
1156
+ rand_ids = torch.randperm(num_codes, device=device)[
1157
+ : self.orthogonal_reg_max_codes
1158
+ ]
1159
+ codebook = codebook[:, rand_ids]
1160
+
1161
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
1162
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
1163
+
1164
+ # handle multi-headed quantized embeddings
1165
+
1166
+ if is_multiheaded:
1167
+ if self.separate_codebook_per_head:
1168
+ quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
1169
+ else:
1170
+ quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)
1171
+
1172
+ # project out
1173
+
1174
+ quantize = self.project_out(quantize)
1175
+
1176
+ # rearrange quantized embeddings
1177
+
1178
+ if need_transpose:
1179
+ quantize = rearrange(quantize, "b n d -> b d n")
1180
+
1181
+ if self.accept_image_fmap:
1182
+ quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width)
1183
+
1184
+ if only_one:
1185
+ quantize = rearrange(quantize, "b 1 d -> b d")
1186
+
1187
+ # if masking, only return quantized for where mask has True
1188
+
1189
+ if exists(mask):
1190
+ quantize = torch.where(
1191
+ rearrange(mask, "... -> ... 1"), quantize, orig_input
1192
+ )
1193
+
1194
+ return quantize, embed_ind, loss