nextrec 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,826 @@
"""
Residual Quantized Variational AutoEncoder (RQ-VAE) for Generative Recommendation.

Created: 11/12/2025
Last edited: 13/12/2025
Author: Yang Zhou, zyaztec@gmail.com
Source code reference:
    [1] Tencent-Advertising-Algorithm-Competition-2025-Baseline
References:
    [1] Lee et al. Autoregressive Image Generation using Residual Quantization. CVPR 2022.
    [2] Zeghidour et al. SoundStream: An End-to-End Neural Audio Codec. IEEE/ACM TASLP 2021.

RQ-VAE learns hierarchical discrete representations via residual quantization.
It encodes continuous embeddings (e.g., item/user or multimodal embeddings) into
multi-level semantic IDs, enabling downstream tasks such as retrieval,
classification, and generation.

Architecture:
    (1) Encoder: projects input embeddings into a latent space
    (2) Residual Quantizer (RQ): multi-level vector quantization on residuals
    (3) Decoder: reconstructs the original embeddings from the quantized latents
    (4) Training: reconstruction loss + codebook/commitment loss

Key features:
    - Hierarchical semantic ID extraction for multi-level representations
    - Flexible codebook initialization (K-Means, Balanced K-Means, or random)
    - Balanced K-Means ensures a uniform cluster distribution
    - Shared or independent codebooks across quantization levels
    - Cosine or L2 distance metrics for vector quantization
"""

from __future__ import annotations

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
from typing import cast
import logging
import tqdm

from torch.utils.data import DataLoader

from nextrec.basic.features import DenseFeature
from nextrec.basic.model import BaseModel
from nextrec.data.batch_utils import batch_to_dict
from nextrec.basic.loggers import colorize, setup_logger


def kmeans(
    data: torch.Tensor, n_clusters: int, kmeans_iters: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run scikit-learn KMeans on CPU and return (centers, labels) as tensors."""
    km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")

    data_cpu = data.detach().cpu()
    np_data = data_cpu.numpy()
    km.fit(np_data)
    return torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)
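

# --- Illustrative sketch (editor's note, not part of the package source) ---
# Shows how `kmeans` above can seed a codebook from a batch of embeddings;
# the tensor sizes and cluster count are arbitrary demo choices.
def _demo_kmeans_init() -> None:
    emb = torch.randn(512, 32)  # 512 embeddings of dim 32
    centers, labels = kmeans(emb, n_clusters=16, kmeans_iters=20)
    assert centers.shape == (16, 32)  # one centroid per cluster
    assert labels.shape == (512,)  # a cluster id per embedding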


class BalancedKmeans(nn.Module):
    """Balanced K-Means clustering implementation.

    Ensures clusters have approximately equal numbers of samples.
    """

    def __init__(
        self, num_clusters: int, kmeans_iters: int, tolerance: float, device: str
    ):
        super().__init__()
        self.num_clusters = num_clusters
        self.kmeans_iters = kmeans_iters
        self.tolerance = tolerance
        self.device = device
        self.codebook: torch.Tensor | None = None  # type: ignore

    def compute_distances(self, data: torch.Tensor) -> torch.Tensor:
        if self.codebook is None:
            raise RuntimeError(
                "Codebook is not initialized before computing distances."
            )
        return torch.cdist(data, self.codebook)

    def assign_clusters(self, dist: torch.Tensor) -> torch.Tensor:
        samples_cnt = dist.shape[0]
        samples_labels = torch.empty(samples_cnt, dtype=torch.long, device=self.device)
        clusters_cnt = torch.zeros(
            self.num_clusters, dtype=torch.long, device=self.device
        )

        # Cap each cluster at ceil(N / K) members to keep assignments balanced.
        max_per_cluster = math.ceil(samples_cnt / self.num_clusters)

        sorted_indices = torch.argsort(dist, dim=-1)

        for i in range(samples_cnt):
            assigned = False
            for j in range(self.num_clusters):
                cluster_idx = sorted_indices[i, j]
                if clusters_cnt[cluster_idx] < max_per_cluster:
                    samples_labels[i] = cluster_idx
                    clusters_cnt[cluster_idx] += 1
                    assigned = True
                    break

            if not assigned:
                # All preferred clusters are full; fall back to the least populated one.
                cluster_idx = torch.argmin(clusters_cnt)
                samples_labels[i] = cluster_idx
                clusters_cnt[cluster_idx] += 1

        return samples_labels

    def update_codebook(
        self, data: torch.Tensor, samples_labels: torch.Tensor
    ) -> torch.Tensor:
        new_codebook = []
        for i in range(self.num_clusters):
            cluster_data = data[samples_labels == i]
            if len(cluster_data) > 0:
                new_codebook.append(cluster_data.mean(dim=0))
            else:
                # Keep the previous centroid for an empty cluster.
                assert self.codebook is not None
                new_codebook.append(self.codebook[i])
        return torch.stack(new_codebook)

    def fit(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if self.kmeans_iters < 1:
            raise ValueError("kmeans_iters must be >= 1.")
        num_emb, _ = data.shape
        data = data.to(self.device)

        # Initialize the codebook with random samples.
        # If num_emb < num_clusters, sample with replacement.
        if num_emb >= self.num_clusters:
            indices = torch.randperm(num_emb)[: self.num_clusters]
            self.codebook = data[indices].clone()
        else:
            # Sample with replacement and add small random noise.
            indices = torch.randint(0, num_emb, (self.num_clusters,))
            self.codebook = data[indices].clone()
            self.codebook += torch.randn_like(self.codebook) * 0.01

        for _ in range(self.kmeans_iters):
            dist = self.compute_distances(data)
            samples_labels = self.assign_clusters(dist)
            _new_codebook = self.update_codebook(data, samples_labels)
            shift = torch.norm(_new_codebook - self.codebook)
            self.codebook = _new_codebook
            if shift < self.tolerance:
                break

        return self.codebook, samples_labels

    def predict(self, data: torch.Tensor) -> torch.Tensor:
        data = data.to(self.device)
        dist = self.compute_distances(data)
        samples_labels = self.assign_clusters(dist)
        return samples_labels
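

# --- Illustrative sketch (editor's note, not part of the package source) ---
# Balanced K-Means caps every cluster at ceil(N / K) members, so the resulting
# cluster sizes differ by at most one; all sizes below are arbitrary.
def _demo_balanced_kmeans() -> None:
    bkm = BalancedKmeans(num_clusters=8, kmeans_iters=10, tolerance=1e-4, device="cpu")
    codebook, labels = bkm.fit(torch.randn(100, 16))
    counts = torch.bincount(labels, minlength=8)
    assert codebook.shape == (8, 16)
    assert counts.max().item() <= math.ceil(100 / 8)  # no cluster exceeds the cap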


class RQEncoder(nn.Module):
    """Encoder network for RQ-VAE."""

    def __init__(self, input_dim: int, hidden_dims: list, latent_dim: int):
        super().__init__()

        self.stages = nn.ModuleList()
        in_dim = input_dim

        for out_dim in hidden_dims:
            stage = nn.Sequential(
                nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()
            )
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(nn.Linear(in_dim, latent_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.stages:
            x = stage(x)
        return x


class RQDecoder(nn.Module):
    """Decoder network for RQ-VAE."""

    def __init__(self, latent_dim: int, hidden_dims: list, output_dim: int):
        super().__init__()

        self.stages = nn.ModuleList()
        in_dim = latent_dim

        for out_dim in hidden_dims:
            stage = nn.Sequential(
                nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()
            )
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(nn.Linear(in_dim, output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.stages:
            x = stage(x)
        return x
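

# --- Illustrative sketch (editor's note, not part of the package source) ---
# The decoder mirrors the encoder by reversing hidden_dims (RQVAE.__init__
# below passes hidden_dims[::-1]); the dimensions here are arbitrary.
def _demo_encoder_decoder_roundtrip() -> None:
    enc = RQEncoder(input_dim=64, hidden_dims=[128, 64], latent_dim=32)
    dec = RQDecoder(latent_dim=32, hidden_dims=[64, 128], output_dim=64)
    x = torch.randn(8, 64)  # a batch of 8 input embeddings
    z = enc(x)  # [8, 32] continuous latents
    x_hat = dec(z)  # [8, 64] reconstructions
    assert z.shape == (8, 32) and x_hat.shape == (8, 64)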


# Vector Quantization
class VQEmbedding(nn.Embedding):
    """
    Vector Quantization (VQ) embedding module used in VQ-VAE / RQ-VAE.

    This module maintains a learnable codebook and maps continuous input
    vectors to their nearest codebook entries. It supports optional one-time
    codebook initialization with K-Means to improve training stability.

    - The codebook is initialized lazily on the first forward pass.
    - Nearest-neighbor search uses either L2 or cosine distance.
    - The module outputs both the quantized embeddings and their discrete
      semantic IDs (codebook indices).

    Typical input shape:
        data: Tensor of shape [N, D], where D == codebook_emb_dim

    Output:
        q: Quantized embeddings of shape [N, D]
        semantic_id: Discrete codebook indices of shape [N]
    """

    def __init__(
        self,
        num_clusters: int,
        codebook_emb_dim: int,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        device: str,
    ):
        super().__init__(num_clusters, codebook_emb_dim)

        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device
        self.codebook_initialized = False

    def create_codebook(self, data: torch.Tensor) -> None:
        if self.codebook_initialized:
            return

        if self.kmeans_method == "kmeans":
            codebook, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == "bkmeans":
            bkmeans = BalancedKmeans(
                num_clusters=self.num_clusters,
                kmeans_iters=self.kmeans_iters,
                tolerance=1e-4,
                device=self.device,
            )
            codebook, _ = bkmeans.fit(data)
        else:
            # Any other value falls back to random initialization.
            codebook = torch.randn(self.num_clusters, self.codebook_emb_dim)
        codebook = codebook.to(self.device)
        assert codebook.shape == (self.num_clusters, self.codebook_emb_dim)
        self.weight.data.copy_(codebook)
        self.codebook_initialized = True

    @torch.no_grad()
    def compute_distances(self, data: torch.Tensor) -> torch.Tensor:
        codebook_t = self.weight.t()
        assert codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim

        if self.distances_method == "cosine":
            data_norm = F.normalize(data, p=2, dim=-1)
            _codebook_t_norm = F.normalize(codebook_t, p=2, dim=0)
            distances = 1 - torch.mm(data_norm, _codebook_t_norm)
        else:
            # Squared L2 via addmm: ||x||^2 + ||c||^2 - 2 * x @ c.
            data_norm_sq = data.pow(2).sum(dim=-1, keepdim=True)
            codebook_t_norm_sq = codebook_t.pow(2).sum(dim=0, keepdim=True)
            distances = torch.addmm(
                data_norm_sq + codebook_t_norm_sq,
                data,
                codebook_t,
                beta=1.0,
                alpha=-2.0,
            )
        return distances

    @torch.no_grad()
    def create_semantic_id(self, data: torch.Tensor) -> torch.Tensor:
        distances = self.compute_distances(data)
        semantic_id = torch.argmin(distances, dim=-1)
        return semantic_id

    def update_emb(self, semantic_id: torch.Tensor) -> torch.Tensor:
        return super().forward(semantic_id)

    def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        self.create_codebook(data)
        semantic_id = self.create_semantic_id(data)
        q = super().forward(semantic_id)  # codebook lookup
        return q, semantic_id
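

# --- Illustrative sketch (editor's note, not part of the package source) ---
# Sanity-checks the fused squared-L2 distance in VQEmbedding against
# torch.cdist; the sizes and the random-init method are arbitrary demo choices.
def _demo_vq_distances() -> None:
    vq = VQEmbedding(
        num_clusters=32,
        codebook_emb_dim=16,
        kmeans_method="random",
        kmeans_iters=0,
        distances_method="l2",
        device="cpu",
    )
    data = torch.randn(10, 16)
    q, ids = vq(data)  # lazily initializes the codebook
    fused = vq.compute_distances(data)  # squared L2 via addmm
    reference = torch.cdist(data, vq.weight).pow(2)
    assert torch.allclose(fused, reference, atol=1e-3)
    assert q.shape == data.shape and ids.shape == (10,)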


# Residual Quantizer
class RQ(nn.Module):
    """
    Residual Quantization (RQ) module for RQ-VAE.

    This module performs multi-stage vector quantization on continuous latent
    representations using a stack of VQEmbedding codebooks. Each codebook
    quantizes the residual left by the previous quantization stage, enabling
    fine-grained approximation with multiple small codebooks.

        z_e (continuous latent)

        q_1 = VQ_1(z_e)
        r_2 = z_e - q_1

        q_2 = VQ_2(r_2)
        ...

        z_q = q_1 + q_2 + ... + q_L

    Input shape:
        data: Tensor of shape [N, codebook_emb_dim]

    Output:
        zq_list: List of accumulated quantized tensors at each level
        semantic_ids: Tensor of discrete codebook indices
        rq_loss: Total RQ-VAE quantization loss
    """

    def __init__(
        self,
        num_codebooks: int,
        codebook_size: list,
        codebook_emb_dim: int,
        shared_codebook: bool,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook

        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.loss_beta = loss_beta
        self.device = device

        if self.shared_codebook:
            self.shared_vq = VQEmbedding(
                self.codebook_size[0],
                self.codebook_emb_dim,
                self.kmeans_method,
                self.kmeans_iters,
                self.distances_method,
                self.device,
            )
            self.vqmodules = None
        else:
            self.shared_vq = None
            self.vqmodules = nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[idx],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for idx in range(self.num_codebooks)
                ]
            )

    # Get the VQ module for a specific level.
    def vq(self, level: int) -> VQEmbedding:
        if self.shared_codebook:
            assert self.shared_vq is not None
            return self.shared_vq
        assert self.vqmodules is not None
        return self.vqmodules[level]

    # Residual quantization: transforms continuous data into discrete codes.
    def quantize(
        self, data: torch.Tensor
    ) -> tuple[
        list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], torch.Tensor
    ]:
        r = data
        z_q = torch.zeros_like(data)

        r_in_list, q_list, zq_list, semantic_id_list = [], [], [], []

        for i in range(self.num_codebooks):
            r_in = r  # current residual
            vq = self.vq(i)  # VQ module for the current level
            q, ids = vq(r_in)  # q: quantized embedding, ids: semantic IDs

            # **IMPORTANT** straight-through estimator: the forward value is q,
            # while the gradient flows to r_in (quantization is treated as
            # identity in the backward pass).
            q_st = r_in + (q - r_in).detach()
            z_q = z_q + q_st  # accumulate quantized embeddings
            r = r - q_st  # update the residual

            r_in_list.append(r_in)
            q_list.append(q)
            zq_list.append(z_q)
            semantic_id_list.append(ids.unsqueeze(-1))

        semantic_ids = torch.cat(semantic_id_list, dim=-1)
        # zq_list: accumulated quantized embeddings at each level
        # r_in_list: residuals before quantization at each level
        # q_list: quantized embeddings at each level
        # semantic_ids: [N, num_codebooks] discrete codebook indices
        return zq_list, r_in_list, q_list, semantic_ids

    def rqvae_loss(
        self, r_in_list: list[torch.Tensor], q_list: list[torch.Tensor]
    ) -> torch.Tensor:
        losses = []
        for r_in, q in zip(r_in_list, q_list):
            # Codebook loss: move the codebook towards the encoder output
            # (stop-gradient on the encoder side).
            codebook_loss = (q - r_in.detach()).pow(2.0).mean()

            # Commitment loss: encourage encoder outputs to commit to the
            # codebook (stop-gradient on the codebook side).
            commit_loss = (r_in - q.detach()).pow(2.0).mean()

            losses.append(codebook_loss + self.loss_beta * commit_loss)

        return torch.stack(losses).sum()

    def forward(
        self, data: torch.Tensor
    ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
        zq_list, r_in_list, q_list, semantic_ids = self.quantize(data)
        rq_loss = self.rqvae_loss(r_in_list, q_list)
        return zq_list, semantic_ids, rq_loss
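

# --- Illustrative sketch (editor's note, not part of the package source) ---
# Each level quantizes the residual left by the previous one; with trained or
# K-Means-initialized codebooks the accumulated approximation error typically
# shrinks level by level. All sizes below are arbitrary.
def _demo_residual_quantization() -> None:
    rq = RQ(
        num_codebooks=3,
        codebook_size=[64, 64, 64],
        codebook_emb_dim=16,
        shared_codebook=False,
        kmeans_method="kmeans",
        kmeans_iters=10,
        distances_method="l2",
        loss_beta=0.25,
        device="cpu",
    )
    z_e = torch.randn(512, 16)
    zq_list, semantic_ids, rq_loss = rq(z_e)  # codebooks init lazily here
    assert semantic_ids.shape == (512, 3)  # one codebook index per level
    errors = [(z_e - zq).norm().item() for zq in zq_list]
    print("approximation error after each level:", errors)
    print("rq loss:", rq_loss.item())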


# RQ-VAE Model
class RQVAE(BaseModel):

    @property
    def model_name(self) -> str:
        return "RQVAE"

    @property
    def default_task(self) -> str:
        # Task is unused for unsupervised training; keep a valid default for BaseModel.
        return "regression"

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list,
        latent_dim: int,
        num_codebooks: int,
        codebook_size: list,
        shared_codebook: bool,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        loss_beta: float,
        device: str,
        dense_features: list[DenseFeature] | None = None,
        target: str | list[str] | None = None,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.loss_beta = loss_beta

        super().__init__(
            dense_features=dense_features,
            sparse_features=None,
            sequence_features=None,
            target=target,
            task=self.default_task,
            device=device,
            **kwargs,
        )

        self.encoder = RQEncoder(input_dim, hidden_dims, latent_dim).to(self.device)
        self.decoder = RQDecoder(latent_dim, hidden_dims[::-1], input_dim).to(
            self.device
        )
        self.rq = RQ(
            num_codebooks,
            codebook_size,
            latent_dim,
            shared_codebook,
            kmeans_method,
            kmeans_iters,
            distances_method,
            loss_beta,
            self.device,
        ).to(self.device)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, z_vq: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
        # A list holds the accumulated latents per level; decode the final one.
        if isinstance(z_vq, list):
            z_vq = z_vq[-1]
        return self.decoder(z_vq)

    def compute_loss(
        self, x_hat: torch.Tensor, x_gt: torch.Tensor, rqvae_loss: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        total_loss = recon_loss + rqvae_loss
        return recon_loss, rqvae_loss, total_loss

    def prepare_loader(
        self,
        data: DataLoader | dict | list | tuple,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
    ) -> DataLoader:
        if isinstance(data, DataLoader):
            return data
        dataloader = self.prepare_data_loader(
            data=data,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        if isinstance(dataloader, tuple):
            loader = dataloader[0]
        else:
            loader = dataloader
        return cast(DataLoader, loader)

    # Extract input embeddings from batch data.
    def extract_embeddings(self, batch_data) -> torch.Tensor:
        batch_dict = batch_to_dict(batch_data)
        X_input, _ = self.get_input(batch_dict, require_labels=False)

        if not self.all_features:
            raise ValueError(
                "[RQVAE] dense_features are required to use fit/predict helpers."
            )
        tensors: list[torch.Tensor] = []
        for name in self.feature_names:
            if name not in X_input:
                raise KeyError(
                    f"[RQVAE] Feature '{name}' not found in input batch. Available keys: {list(X_input.keys())}"
                )
            tensors.append(X_input[name].to(self.device).float())
        if not tensors:
            raise ValueError("[RQVAE] No feature tensors found in batch.")
        init_embedding = tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=-1)
        if init_embedding.shape[-1] != self.input_dim:
            raise ValueError(
                f"[RQVAE] Input dim mismatch: expected {self.input_dim}, got {init_embedding.shape[-1]}."
            )

        return init_embedding

    def init_codebook(self, train_loader: DataLoader, init_batches: int) -> None:
        cached: list[torch.Tensor] = []
        for batch_idx, batch in enumerate(train_loader):
            cached.append(self.extract_embeddings(batch))
            if batch_idx >= init_batches - 1:
                break
        if not cached:
            raise ValueError("[RQVAE] No data available for codebook initialization.")

        init_data = torch.cat(cached, dim=0)

        with torch.no_grad():
            # Encode to latent space, [num_samples, latent_dim].
            z_e = self.encode(init_data)

            r = z_e  # current residual

            for level in range(self.num_codebooks):
                vq = self.rq.vq(level)
                if not vq.codebook_initialized:
                    vq.create_codebook(r)
                q, _ = vq(r)  # quantize the current residual
                r = r - q  # update the residual

    def get_semantic_ids(self, x_gt: torch.Tensor) -> torch.Tensor:
        z_e = self.encode(x_gt)  # encode the source input to latent space
        _, semantic_ids, _ = self.rq(z_e)
        return semantic_ids

    def forward(
        self, x_gt: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_ids, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_ids, recon_loss, rqvae_loss, total_loss

    def fit(
        self,
        train_data: DataLoader | dict | list | tuple,
        valid_data: DataLoader | dict | list | tuple | None = None,
        epochs: int = 1,
        batch_size: int = 256,
        shuffle: bool = True,
        num_workers: int = 0,
        lr: float = 1e-3,
        init_batches: int = 3,
    ):
        """
        Train RQ-VAE.

        Args:
            train_data: Training data (DataLoader, dict, or array-like) that matches dense_features.
            valid_data: Optional validation data for monitoring loss.
            epochs: Number of training epochs.
            batch_size: Batch size used when building a DataLoader from raw data.
            shuffle: Shuffle training data when constructing a DataLoader.
            num_workers: Number of DataLoader workers.
            lr: Learning rate for the Adam optimizer.
            init_batches: Number of batches used to initialize the codebooks.
        """
        train_loader = self.prepare_loader(
            data=train_data,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        valid_loader = (
            self.prepare_loader(
                data=valid_data,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
            )
            if valid_data is not None
            else None
        )

        if not self.logger_initialized and self.is_main_process:
            setup_logger(session_id=self.session_id)
            self.logger_initialized = True

        # Minimal placeholders to satisfy BaseModel.summary when running unsupervised.
        if not hasattr(self, "metrics"):
            self.metrics = ["loss"]
        if not hasattr(self, "task_specific_metrics"):
            self.task_specific_metrics = {}
        if not hasattr(self, "best_metrics_mode"):
            self.best_metrics_mode = "min"

        self.to(self.device)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.train()

        # Iterable-style (streaming) loaders have no len().
        try:
            steps_per_epoch = len(train_loader)
            is_streaming = False
        except TypeError:
            steps_per_epoch = None
            is_streaming = True

        self.init_codebook(train_loader, init_batches=init_batches)

        if self.is_main_process:
            self.summary()
            logging.info("")
            logging.info(colorize("=" * 80, bold=True))
            logging.info(
                colorize(
                    "Start streaming training" if is_streaming else "Start training",
                    bold=True,
                )
            )
            logging.info(colorize("=" * 80, bold=True))
            logging.info("")
            logging.info(colorize(f"Model device: {self.device}", bold=True))

        for epoch in range(epochs):
            total_loss = 0.0
            step_count = 0
            if is_streaming and self.is_main_process:
                logging.info("")
                logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True))
            if is_streaming:
                batch_iter = enumerate(train_loader)
            else:
                tqdm_disable = not self.is_main_process
                batch_iter = enumerate(
                    tqdm.tqdm(
                        train_loader,
                        desc=f"Epoch {epoch + 1}/{epochs}",
                        total=steps_per_epoch,
                        disable=tqdm_disable,
                    )
                )
            for _, batch in batch_iter:
                embeddings = self.extract_embeddings(batch)
                _, _, recon_loss, rqvae_loss, total_batch_loss = self(embeddings)

                optimizer.zero_grad()
                total_batch_loss.backward()
                optimizer.step()

                total_loss += total_batch_loss.item()
                step_count += 1

            denom = steps_per_epoch if steps_per_epoch is not None else step_count
            avg_loss = total_loss / max(1, denom)
            train_log = f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_loss:.4f}"

            if valid_loader is not None:
                val_total = 0.0
                val_steps = 0
                self.eval()  # freeze BatchNorm statistics during validation
                with torch.no_grad():
                    for batch in valid_loader:
                        embeddings = self.extract_embeddings(batch)
                        _, _, _, _, val_loss = self(embeddings)
                        val_total += val_loss.item()
                        val_steps += 1
                self.train()
                try:
                    val_denom = len(valid_loader)
                except TypeError:
                    val_denom = val_steps
                val_avg = val_total / max(1, val_denom)
                if self.is_main_process:
                    logging.info(colorize(train_log))
                    logging.info(
                        colorize(
                            f" Epoch {epoch + 1}/{epochs} - Valid Loss: {val_avg:.4f}",
                            color="cyan",
                        )
                    )
            elif self.is_main_process:
                logging.info(colorize(train_log))

        if self.is_main_process:
            logging.info(" ")
            logging.info(colorize("Training finished.", bold=True))
            logging.info(" ")
        return self

    def predict(
        self,
        data: DataLoader | dict | list | tuple,
        batch_size: int = 256,
        num_workers: int = 0,
        return_reconstruction: bool = False,
        as_numpy: bool = True,
    ) -> np.ndarray | torch.Tensor:
        """
        Generate semantic IDs or reconstructed embeddings.

        Args:
            data: Input data aligned with dense_features.
            batch_size: Batch size used when building a DataLoader from raw data.
            num_workers: Number of DataLoader workers.
            return_reconstruction: If True, return reconstructed embeddings; otherwise, return semantic IDs.
            as_numpy: Whether to return a NumPy array; if False, returns a torch.Tensor on CPU.
        """
        data_loader = self.prepare_loader(
            data=data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        outputs: list[torch.Tensor] = []
        self.eval()
        with torch.no_grad():
            for batch in data_loader:
                embeddings = self.extract_embeddings(batch)
                if return_reconstruction:
                    x_hat, _, _, _, _ = self(embeddings)
                    outputs.append(x_hat.detach().cpu())
                else:
                    semantic_ids = self.get_semantic_ids(embeddings)
                    outputs.append(semantic_ids.detach().cpu())

        if outputs:
            result = torch.cat(outputs, dim=0)
        else:
            out_dim = self.input_dim if return_reconstruction else self.num_codebooks
            result = torch.empty((0, out_dim))
        return result.numpy() if as_numpy else result
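

# --- Illustrative sketch (editor's note, not part of the package source) ---
# End-to-end flow using the building blocks directly (encoder -> RQ -> decoder),
# mirroring RQVAE.forward without the BaseModel/DataLoader plumbing. All sizes,
# codebook settings, and optimizer hyperparameters are arbitrary.
def _demo_rqvae_training_loop() -> None:
    encoder = RQEncoder(input_dim=64, hidden_dims=[128], latent_dim=32)
    decoder = RQDecoder(latent_dim=32, hidden_dims=[128], output_dim=64)
    rq = RQ(
        num_codebooks=2,
        codebook_size=[128, 128],
        codebook_emb_dim=32,
        shared_codebook=False,
        kmeans_method="kmeans",
        kmeans_iters=10,
        distances_method="l2",
        loss_beta=0.25,
        device="cpu",
    )
    params = (
        list(encoder.parameters()) + list(decoder.parameters()) + list(rq.parameters())
    )
    optimizer = torch.optim.Adam(params, lr=1e-3)

    x = torch.randn(256, 64)  # stand-in for pre-trained item embeddings
    for _ in range(5):  # a few illustrative training steps
        z_e = encoder(x)
        zq_list, semantic_ids, rq_loss = rq(z_e)  # codebooks init lazily on step 1
        x_hat = decoder(zq_list[-1])  # decode the final accumulated latent
        loss = F.mse_loss(x_hat, x) + rq_loss  # same objective as RQVAE.compute_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Multi-level semantic IDs for downstream generative retrieval: [256, 2].
    print(semantic_ids.shape)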