nextrec-0.4.5-py3-none-any.whl → nextrec-0.4.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/features.py +4 -0
- nextrec/basic/layers.py +103 -24
- nextrec/basic/metrics.py +71 -1
- nextrec/basic/model.py +285 -186
- nextrec/data/data_processing.py +1 -3
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/generative/__init__.py +16 -0
- nextrec/models/generative/hstu.py +110 -57
- nextrec/models/generative/rqvae.py +826 -0
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/config.py +38 -1
- nextrec/utils/embedding.py +28 -0
- nextrec/utils/initializer.py +4 -4
- nextrec/utils/synthetic_data.py +19 -0
- nextrec-0.4.7.dist-info/METADATA +376 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/RECORD +26 -25
- nextrec-0.4.5.dist-info/METADATA +0 -357
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/WHEEL +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,826 @@
"""
Residual Quantized Variational AutoEncoder (RQ-VAE) for Generative Recommendation.

Date: created on 11/12/2025
Checkpoint: edit on 13/12/2025
Author: Yang Zhou, zyaztec@gmail.com
Source code reference:
    [1] Tencent-Advertising-Algorithm-Competition-2025-Baseline
Reference:
    [1] Lee et al. Autoregressive Image Generation using Residual Quantization. CVPR 2022.
    [2] Zeghidour et al. SoundStream: An End-to-End Neural Audio Codec. IEEE/ACM TASLP 2021.

RQ-VAE learns hierarchical discrete representations via residual quantization.
It encodes continuous embeddings (e.g., item/user or multimodal embeddings) into
multi-level semantic IDs, enabling downstream tasks like retrieval,
classification, or generation.

Architecture:
    (1) Encoder: projects input embeddings into the latent space
    (2) Residual Quantizer (RQ): multi-level vector quantization on residuals
    (3) Decoder: reconstructs the original embeddings from quantized latents
    (4) Training: reconstruction loss + codebook/commitment loss

Key Features:
    - Hierarchical semantic ID extraction for multi-level representations
    - Flexible codebook initialization (K-Means, Balanced K-Means, Random)
    - Balanced K-Means ensures a uniform cluster distribution
    - Shared or independent codebooks across quantization levels
    - Cosine or L2 distance metrics for vector quantization
"""
from __future__ import annotations

import logging
import math
from typing import cast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader

from nextrec.basic.features import DenseFeature
from nextrec.basic.loggers import colorize, setup_logger
from nextrec.basic.model import BaseModel
from nextrec.data.batch_utils import batch_to_dict


def kmeans(
    data: torch.Tensor, n_clusters: int, kmeans_iters: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run scikit-learn K-Means on a tensor; returns (centers, labels) as CPU tensors."""
    km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")
    np_data = data.detach().cpu().numpy()
    km.fit(np_data)
    return torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)
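

# Illustrative sketch, not part of the packaged rqvae.py: exercises the contract
# of the `kmeans` helper above on random data (centers are [K, D], labels are [N]).
def _kmeans_contract_demo() -> None:
    data = torch.randn(1000, 32)
    centers, labels = kmeans(data, n_clusters=64, kmeans_iters=20)
    assert centers.shape == (64, 32)
    assert labels.shape == (1000,)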


class BalancedKmeans(nn.Module):
    """Balanced K-Means clustering implementation.

    Ensures clusters have approximately equal numbers of samples: each cluster
    accepts at most ceil(N / num_clusters) members during assignment.
    Note: kmeans_iters must be >= 1 for fit() to produce labels.
    """

    def __init__(
        self, num_clusters: int, kmeans_iters: int, tolerance: float, device: str
    ):
        super().__init__()
        self.num_clusters = num_clusters
        self.kmeans_iters = kmeans_iters
        self.tolerance = tolerance
        self.device = device
        self.codebook: torch.Tensor | None = None

    def compute_distances(self, data: torch.Tensor) -> torch.Tensor:
        if self.codebook is None:
            raise RuntimeError(
                "Codebook is not initialized before computing distances."
            )
        return torch.cdist(data, self.codebook)

    def assign_clusters(self, dist: torch.Tensor) -> torch.Tensor:
        samples_cnt = dist.shape[0]
        samples_labels = torch.empty(samples_cnt, dtype=torch.long, device=self.device)
        clusters_cnt = torch.zeros(
            self.num_clusters, dtype=torch.long, device=self.device
        )

        # Cap each cluster so the assignment stays balanced.
        max_per_cluster = math.ceil(samples_cnt / self.num_clusters)

        sorted_indices = torch.argsort(dist, dim=-1)

        for i in range(samples_cnt):
            assigned = False
            for j in range(self.num_clusters):
                cluster_idx = sorted_indices[i, j]
                if clusters_cnt[cluster_idx] < max_per_cluster:
                    samples_labels[i] = cluster_idx
                    clusters_cnt[cluster_idx] += 1
                    assigned = True
                    break

            # All preferred clusters are full: fall back to the emptiest one.
            if not assigned:
                cluster_idx = torch.argmin(clusters_cnt)
                samples_labels[i] = cluster_idx
                clusters_cnt[cluster_idx] += 1

        return samples_labels

    def update_codebook(
        self, data: torch.Tensor, samples_labels: torch.Tensor
    ) -> torch.Tensor:
        new_codebook = []
        for i in range(self.num_clusters):
            cluster_data = data[samples_labels == i]
            if len(cluster_data) > 0:
                new_codebook.append(cluster_data.mean(dim=0))
            else:
                # Keep the previous center for empty clusters.
                assert self.codebook is not None
                new_codebook.append(self.codebook[i])
        return torch.stack(new_codebook)

    def fit(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        num_emb, codebook_emb_dim = data.shape
        data = data.to(self.device)

        # Initialize the codebook with random samples.
        # If num_emb < num_clusters, sample with replacement and add small noise.
        if num_emb >= self.num_clusters:
            indices = torch.randperm(num_emb)[: self.num_clusters]
            self.codebook = data[indices].clone()
        else:
            indices = torch.randint(0, num_emb, (self.num_clusters,))
            self.codebook = data[indices].clone()
            self.codebook += torch.randn_like(self.codebook) * 0.01

        for _ in range(self.kmeans_iters):
            dist = self.compute_distances(data)
            samples_labels = self.assign_clusters(dist)
            _new_codebook = self.update_codebook(data, samples_labels)
            assert self.codebook is not None
            if torch.norm(_new_codebook - self.codebook) < self.tolerance:
                self.codebook = _new_codebook
                break
            self.codebook = _new_codebook

        assert self.codebook is not None
        return self.codebook, samples_labels

    def predict(self, data: torch.Tensor) -> torch.Tensor:
        data = data.to(self.device)
        dist = self.compute_distances(data)
        samples_labels = self.assign_clusters(dist)
        return samples_labels
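

# Illustrative sketch, not part of the packaged rqvae.py: with the cap of
# ceil(N / K) members per cluster, BalancedKmeans keeps cluster sizes within
# one sample of each other when N is not divisible by K.
def _balanced_kmeans_demo() -> None:
    data = torch.randn(100, 16)
    bkm = BalancedKmeans(num_clusters=8, kmeans_iters=10, tolerance=1e-4, device="cpu")
    _, labels = bkm.fit(data)
    counts = torch.bincount(labels, minlength=8)
    assert counts.max().item() <= math.ceil(100 / 8)  # at most 13 samples per cluster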


class RQEncoder(nn.Module):
    """Encoder network for RQ-VAE: an MLP mapping input embeddings to the latent space."""

    def __init__(self, input_dim: int, hidden_dims: list, latent_dim: int):
        super().__init__()

        self.stages = nn.ModuleList()
        in_dim = input_dim

        # Hidden stack: Linear -> BatchNorm -> ReLU per stage.
        for out_dim in hidden_dims:
            stage = nn.Sequential(
                nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()
            )
            self.stages.append(stage)
            in_dim = out_dim

        # Final linear projection into the latent space.
        self.stages.append(nn.Linear(in_dim, latent_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.stages:
            x = stage(x)
        return x


class RQDecoder(nn.Module):
    """Decoder network for RQ-VAE."""

    def __init__(self, latent_dim: int, hidden_dims: list, output_dim: int):
        super().__init__()

        self.stages = nn.ModuleList()
        in_dim = latent_dim

        for out_dim in hidden_dims:
            stage = nn.Sequential(
                nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()
            )
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(nn.Linear(in_dim, output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.stages:
            x = stage(x)
        return x
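

# Illustrative sketch, not part of the packaged rqvae.py: the encoder maps
# [N, input_dim] -> [N, latent_dim], and a decoder built with the reversed
# hidden stack (hidden_dims[::-1], as RQVAE.__init__ does below) maps it back.
def _encoder_decoder_shapes_demo() -> None:
    enc = RQEncoder(input_dim=64, hidden_dims=[128, 64], latent_dim=32)
    dec = RQDecoder(latent_dim=32, hidden_dims=[64, 128], output_dim=64)
    x = torch.randn(8, 64)
    z = enc(x)
    x_hat = dec(z)
    assert z.shape == (8, 32)
    assert x_hat.shape == (8, 64)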


# Vector Quantization
class VQEmbedding(nn.Embedding):
    """
    Vector Quantization (VQ) embedding module used in VQ-VAE / RQ-VAE.

    This module maintains a learnable codebook and maps continuous input
    vectors to their nearest codebook entries. It supports optional one-time
    codebook initialization using K-Means to improve training stability.

    - The codebook is initialized lazily on the first forward pass.
    - Nearest-neighbor search uses either L2 or cosine distance.
    - The module outputs both the quantized embeddings and their discrete
      semantic IDs (codebook indices).

    Typical input shape:
        data: Tensor of shape [N, D], where D == codebook_emb_dim

    Output:
        q: Quantized embeddings of shape [N, D]
        semantic_id: Discrete codebook indices of shape [N]
    """

    def __init__(
        self,
        num_clusters: int,
        codebook_emb_dim: int,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        device: str,
    ):
        super().__init__(num_clusters, codebook_emb_dim)

        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device
        self.codebook_initialized = False

    def create_codebook(self, data: torch.Tensor) -> None:
        """One-time codebook initialization: K-Means, Balanced K-Means, or random."""
        if self.codebook_initialized:
            return

        if self.kmeans_method == "kmeans":
            codebook, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == "bkmeans":
            bkmeans = BalancedKmeans(
                num_clusters=self.num_clusters,
                kmeans_iters=self.kmeans_iters,
                tolerance=1e-4,
                device=self.device,
            )
            codebook, _ = bkmeans.fit(data)
        else:
            codebook = torch.randn(self.num_clusters, self.codebook_emb_dim)
        codebook = codebook.to(self.device)
        assert codebook.shape == (self.num_clusters, self.codebook_emb_dim)
        self.weight.data.copy_(codebook)
        self.codebook_initialized = True

    @torch.no_grad()
    def compute_distances(self, data: torch.Tensor) -> torch.Tensor:
        codebook_t = self.weight.t()
        assert codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim

        if self.distances_method == "cosine":
            # Cosine distance: 1 - cosine similarity of normalized vectors.
            data_norm = F.normalize(data, p=2, dim=-1)
            _codebook_t_norm = F.normalize(codebook_t, p=2, dim=0)
            distances = 1 - torch.mm(data_norm, _codebook_t_norm)
        else:
            # Squared L2 distance: ||x||^2 + ||c||^2 - 2 * x . c
            data_norm_sq = data.pow(2).sum(dim=-1, keepdim=True)
            codebook_t_norm_sq = codebook_t.pow(2).sum(dim=0, keepdim=True)
            distances = torch.addmm(
                data_norm_sq + codebook_t_norm_sq,
                data,
                codebook_t,
                beta=1.0,
                alpha=-2.0,
            )
        return distances

    @torch.no_grad()
    def create_semantic_id(self, data: torch.Tensor) -> torch.Tensor:
        distances = self.compute_distances(data)
        semantic_id = torch.argmin(distances, dim=-1)
        return semantic_id

    def update_emb(self, semantic_id: torch.Tensor) -> torch.Tensor:
        return super().forward(semantic_id)

    def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        self.create_codebook(data)  # no-op after the first call
        semantic_id = self.create_semantic_id(data)
        q = super().forward(semantic_id)  # codebook lookup
        return q, semantic_id
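

# Illustrative sketch, not part of the packaged rqvae.py: the codebook is built
# lazily on the first forward pass, and the quantized output is a plain lookup
# of the nearest codebook rows.
def _vq_embedding_demo() -> None:
    vq = VQEmbedding(
        num_clusters=16,
        codebook_emb_dim=8,
        kmeans_method="kmeans",
        kmeans_iters=10,
        distances_method="l2",
        device="cpu",
    )
    data = torch.randn(200, 8)
    q, ids = vq(data)  # first call triggers K-Means codebook initialization
    assert vq.codebook_initialized
    assert q.shape == (200, 8) and ids.shape == (200,)
    assert torch.equal(q, vq.weight[ids])  # quantization is a codebook lookup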


# Residual Quantizer
class RQ(nn.Module):
    """
    Residual Quantization (RQ) module for RQ-VAE.

    This module performs multi-stage vector quantization on continuous latent
    representations using a stack of VQEmbedding codebooks. Each codebook
    quantizes the residual left by the previous quantization stage, enabling
    fine-grained approximation with multiple small codebooks.

        z_e (continuous latent)
            ↓
        q_1 = VQ_1(z_e)
        r_2 = z_e - q_1
            ↓
        q_2 = VQ_2(r_2)
        ...
            ↓
        z_q = q_1 + q_2 + ... + q_L

    Input shape:
        data: Tensor of shape [N, codebook_emb_dim]

    Output:
        zq_list: List of accumulated quantized tensors at each level
        semantic_ids: Tensor of discrete codebook indices
        rq_loss: Total RQ-VAE quantization loss
    """

    def __init__(
        self,
        num_codebooks: int,
        codebook_size: list,
        codebook_emb_dim: int,
        shared_codebook: bool,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook

        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.loss_beta = loss_beta
        self.device = device

        if self.shared_codebook:
            self.shared_vq = VQEmbedding(
                self.codebook_size[0],
                self.codebook_emb_dim,
                self.kmeans_method,
                self.kmeans_iters,
                self.distances_method,
                self.device,
            )
            self.vqmodules = None
        else:
            self.shared_vq = None
            self.vqmodules = nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[idx],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for idx in range(self.num_codebooks)
                ]
            )

    def vq(self, level: int) -> VQEmbedding:
        """Return the VQ module for a specific quantization level."""
        if self.shared_codebook:
            assert self.shared_vq is not None
            return self.shared_vq
        assert self.vqmodules is not None
        return self.vqmodules[level]

    def quantize(
        self, data: torch.Tensor
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], torch.Tensor]:
        """Residual quantization: transforms continuous data into discrete codes."""
        r = data
        z_q = torch.zeros_like(data)

        r_in_list, q_list, zq_list, semantic_id_list = [], [], [], []

        for i in range(self.num_codebooks):
            r_in = r  # current residual
            vq = self.vq(i)  # get VQ module for current level
            q, ids = vq(r_in)  # q: quantized embedding, ids: semantic IDs

            # **IMPORTANT** straight-through estimator: the forward value is q,
            # but gradients bypass the non-differentiable codebook lookup and
            # flow straight through to r_in.
            q_st = r_in + (q - r_in).detach()
            z_q = z_q + q_st  # accumulate quantized embeddings
            r = r - q_st  # update residual

            r_in_list.append(r_in)
            q_list.append(q)
            zq_list.append(z_q)
            semantic_id_list.append(ids.unsqueeze(-1))

        semantic_ids = torch.cat(semantic_id_list, dim=-1)
        # zq_list: accumulated quantized embeddings at each level
        # r_in_list: residuals before quantization at each level
        # q_list: quantized embeddings at each level
        # semantic_ids: [N, num_codebooks] discrete codebook indices
        return zq_list, r_in_list, q_list, semantic_ids

    def rqvae_loss(
        self, r_in_list: list[torch.Tensor], q_list: list[torch.Tensor]
    ) -> torch.Tensor:
        losses = []
        for r_in, q in zip(r_in_list, q_list):
            # Codebook loss: move the codebook towards encoder outputs
            # (stop gradient on the encoder side).
            codebook_loss = (q - r_in.detach()).pow(2.0).mean()

            # Commitment loss: encourage encoder outputs to commit to the
            # codebook (stop gradient on the codebook side).
            commit_loss = (r_in - q.detach()).pow(2.0).mean()

            losses.append(codebook_loss + self.loss_beta * commit_loss)

        return torch.stack(losses).sum()

    def forward(
        self, data: torch.Tensor
    ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
        zq_list, r_in_list, q_list, semantic_ids = self.quantize(data)
        rq_loss = self.rqvae_loss(r_in_list, q_list)
        return zq_list, semantic_ids, rq_loss
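

# Illustrative sketch, not part of the packaged rqvae.py: each level quantizes
# the residual left by the previous one, so the accumulated z_q in zq_list
# approximates the input increasingly well as depth grows.
def _rq_levels_demo() -> None:
    rq = RQ(
        num_codebooks=3,
        codebook_size=[32, 32, 32],
        codebook_emb_dim=16,
        shared_codebook=False,
        kmeans_method="kmeans",
        kmeans_iters=10,
        distances_method="l2",
        loss_beta=0.25,
        device="cpu",
    )
    z_e = torch.randn(512, 16)
    zq_list, semantic_ids, rq_loss = rq(z_e)
    assert semantic_ids.shape == (512, 3)
    errs = [F.mse_loss(z_q, z_e).item() for z_q in zq_list]
    assert errs[0] >= errs[-1]  # deeper levels refine the approximation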


# RQ-VAE Model
class RQVAE(BaseModel):

    @property
    def model_name(self) -> str:
        return "RQVAE"

    @property
    def default_task(self) -> str:
        # task is unused for unsupervised training; keep a valid default for BaseModel
        return "regression"

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list,
        latent_dim: int,
        num_codebooks: int,
        codebook_size: list,
        shared_codebook: bool,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        loss_beta: float,
        device: str,
        dense_features: list[DenseFeature] | None = None,
        target: str | list[str] | None = None,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.loss_beta = loss_beta

        super().__init__(
            dense_features=dense_features,
            sparse_features=None,
            sequence_features=None,
            target=target,
            task=self.default_task,
            device=device,
            **kwargs,
        )

        self.encoder = RQEncoder(input_dim, hidden_dims, latent_dim).to(self.device)
        self.decoder = RQDecoder(latent_dim, hidden_dims[::-1], input_dim).to(
            self.device
        )
        self.rq = RQ(
            num_codebooks,
            codebook_size,
            latent_dim,
            shared_codebook,
            kmeans_method,
            kmeans_iters,
            distances_method,
            loss_beta,
            self.device,
        ).to(self.device)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, z_vq: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
        # The last entry of zq_list is the fully accumulated quantization.
        if isinstance(z_vq, list):
            z_vq = z_vq[-1]
        return self.decoder(z_vq)

    def compute_loss(
        self, x_hat: torch.Tensor, x_gt: torch.Tensor, rqvae_loss: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        total_loss = recon_loss + rqvae_loss
        return recon_loss, rqvae_loss, total_loss

    def prepare_loader(
        self,
        data: DataLoader | dict | list | tuple,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
    ) -> DataLoader:
        if isinstance(data, DataLoader):
            return data
        dataloader = self.prepare_data_loader(
            data=data,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        if isinstance(dataloader, tuple):
            loader = dataloader[0]
        else:
            loader = dataloader
        return cast(DataLoader, loader)

    def extract_embeddings(self, batch_data) -> torch.Tensor:
        """Extract and concatenate input embeddings from batch data."""
        batch_dict = batch_to_dict(batch_data)
        X_input, _ = self.get_input(batch_dict, require_labels=False)

        if not self.all_features:
            raise ValueError(
                "[RQVAE] dense_features are required to use fit/predict helpers."
            )
        tensors: list[torch.Tensor] = []
        for name in self.feature_names:
            if name not in X_input:
                raise KeyError(
                    f"[RQVAE] Feature '{name}' not found in input batch. Available keys: {list(X_input.keys())}"
                )
            tensors.append(X_input[name].to(self.device).float())
        if not tensors:
            raise ValueError("[RQVAE] No feature tensors found in batch.")
        init_embedding = tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=-1)
        if init_embedding.shape[-1] != self.input_dim:
            raise ValueError(
                f"[RQVAE] Input dim mismatch: expected {self.input_dim}, got {init_embedding.shape[-1]}."
            )

        return init_embedding

    def init_codebook(self, train_loader: DataLoader, init_batches: int) -> None:
        """Warm up every codebook level from the first few training batches."""
        cached: list[torch.Tensor] = []
        for batch_idx, batch in enumerate(train_loader):
            cached.append(self.extract_embeddings(batch))
            if batch_idx >= init_batches - 1:
                break
        if not cached:
            raise ValueError("[RQVAE] No data available for codebook initialization.")

        init_data = torch.cat(cached, dim=0)

        with torch.no_grad():
            # Encode to latent space, [num_samples, latent_dim]
            z_e = self.encode(init_data)

            r = z_e  # current residual

            for level in range(self.num_codebooks):
                vq = self.rq.vq(level)
                if not vq.codebook_initialized:
                    vq.create_codebook(r)
                q, _ = vq(r)  # quantize current residual
                r = r - q  # update residual

    def get_semantic_ids(self, x_gt: torch.Tensor) -> torch.Tensor:
        z_e = self.encode(x_gt)  # encode source input to latent space
        _, semantic_ids, _ = self.rq(z_e)
        return semantic_ids

    def forward(
        self, x_gt: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_ids, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_ids, recon_loss, rqvae_loss, total_loss

    def fit(
        self,
        train_data: DataLoader | dict | list | tuple,
        valid_data: DataLoader | dict | list | tuple | None = None,
        epochs: int = 1,
        batch_size: int = 256,
        shuffle: bool = True,
        num_workers: int = 0,
        lr: float = 1e-3,
        init_batches: int = 3,
    ):
        """
        Train RQ-VAE.

        Args:
            train_data: Training data (DataLoader, dict, or array-like) that matches dense_features.
            valid_data: Optional validation data for monitoring loss.
            epochs: Number of training epochs.
            batch_size: Batch size used when building a DataLoader from raw data.
            shuffle: Shuffle training data when constructing a DataLoader.
            num_workers: Number of DataLoader workers.
            lr: Learning rate for the Adam optimizer.
            init_batches: Number of batches used to initialize the codebook.
        """
        train_loader = self.prepare_loader(
            data=train_data,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        valid_loader = (
            self.prepare_loader(
                data=valid_data,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
            )
            if valid_data is not None
            else None
        )

        if not self.logger_initialized and self.is_main_process:
            setup_logger(session_id=self.session_id)
            self.logger_initialized = True

        # Minimal placeholders to satisfy BaseModel.summary when running unsupervised
        if not hasattr(self, "metrics"):
            self.metrics = ["loss"]
        if not hasattr(self, "task_specific_metrics"):
            self.task_specific_metrics = {}
        if not hasattr(self, "best_metrics_mode"):
            self.best_metrics_mode = "min"

        self.to(self.device)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.train()

        # Streaming (iterable-style) loaders have no len(); fall back to step counting.
        try:
            steps_per_epoch = len(train_loader)
            is_streaming = False
        except TypeError:
            steps_per_epoch = None
            is_streaming = True

        self.init_codebook(train_loader, init_batches=init_batches)

        if self.is_main_process:
            self.summary()
            logging.info("")
            logging.info(colorize("=" * 80, bold=True))
            logging.info(
                colorize(
                    "Start streaming training" if is_streaming else "Start training",
                    bold=True,
                )
            )
            logging.info(colorize("=" * 80, bold=True))
            logging.info("")
            logging.info(colorize(f"Model device: {self.device}", bold=True))

        for epoch in range(epochs):
            total_loss = 0.0
            step_count = 0
            if is_streaming and self.is_main_process:
                logging.info("")
                logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True))
            if is_streaming:
                batch_iter = enumerate(train_loader)
            else:
                tqdm_disable = not self.is_main_process
                batch_iter = enumerate(
                    tqdm.tqdm(
                        train_loader,
                        desc=f"Epoch {epoch + 1}/{epochs}",
                        total=steps_per_epoch,
                        disable=tqdm_disable,
                    )
                )
            for _, batch in batch_iter:
                embeddings = self.extract_embeddings(batch)
                _, _, recon_loss, rqvae_loss, total_batch_loss = self(embeddings)

                optimizer.zero_grad()
                total_batch_loss.backward()
                optimizer.step()

                total_loss += total_batch_loss.item()
                step_count += 1

            denom = steps_per_epoch if steps_per_epoch is not None else step_count
            avg_loss = total_loss / max(1, denom)
            train_log = f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_loss:.4f}"

            if valid_loader is not None:
                val_total = 0.0
                val_steps = 0
                self.eval()  # use eval-mode BatchNorm statistics for validation
                with torch.no_grad():
                    for batch in valid_loader:
                        embeddings = self.extract_embeddings(batch)
                        _, _, _, _, val_loss = self(embeddings)
                        val_total += val_loss.item()
                        val_steps += 1
                self.train()
                try:
                    val_denom = len(valid_loader)
                except TypeError:
                    val_denom = val_steps
                val_avg = val_total / max(1, val_denom)
                if self.is_main_process:
                    logging.info(colorize(train_log))
                    logging.info(
                        colorize(
                            f"  Epoch {epoch + 1}/{epochs} - Valid Loss: {val_avg:.4f}",
                            color="cyan",
                        )
                    )
            elif self.is_main_process:
                logging.info(colorize(train_log))

        if self.is_main_process:
            logging.info(" ")
            logging.info(colorize("Training finished.", bold=True))
            logging.info(" ")
        return self

    def predict(
        self,
        data: DataLoader | dict | list | tuple,
        batch_size: int = 256,
        num_workers: int = 0,
        return_reconstruction: bool = False,
        as_numpy: bool = True,
    ) -> np.ndarray | torch.Tensor:
        """
        Generate semantic IDs or reconstructed embeddings.

        Args:
            data: Input data aligned with dense_features.
            batch_size: Batch size used when building a DataLoader from raw data.
            num_workers: Number of DataLoader workers.
            return_reconstruction: If True, return reconstructed embeddings; otherwise, return semantic IDs.
            as_numpy: Whether to return a NumPy array; if False, returns a torch.Tensor on CPU.
        """
        data_loader = self.prepare_loader(
            data=data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        outputs: list[torch.Tensor] = []
        self.eval()
        with torch.no_grad():
            for batch in data_loader:
                embeddings = self.extract_embeddings(batch)
                if return_reconstruction:
                    x_hat, _, _, _, _ = self(embeddings)
                    outputs.append(x_hat.detach().cpu())
                else:
                    semantic_ids = self.get_semantic_ids(embeddings)
                    outputs.append(semantic_ids.detach().cpu())

        if outputs:
            result = torch.cat(outputs, dim=0)
        else:
            out_dim = self.input_dim if return_reconstruction else self.num_codebooks
            result = torch.empty((0, out_dim))
        return result.numpy() if as_numpy else result
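

# Illustrative sketch, not part of the packaged rqvae.py: drives RQVAE.forward
# directly with a raw embedding batch, sidestepping the DataLoader/feature
# plumbing. Assumes BaseModel accepts the dense_features=None default that the
# constructor above declares.
def _rqvae_end_to_end_demo() -> None:
    model = RQVAE(
        input_dim=64,
        hidden_dims=[128, 64],
        latent_dim=32,
        num_codebooks=3,
        codebook_size=[256, 256, 256],
        shared_codebook=False,
        kmeans_method="kmeans",
        kmeans_iters=10,
        distances_method="l2",
        loss_beta=0.25,
        device="cpu",
    )
    x = torch.randn(512, 64)
    x_hat, semantic_ids, recon_loss, rq_loss, total_loss = model(x)
    assert x_hat.shape == x.shape          # reconstruction matches the input shape
    assert semantic_ids.shape == (512, 3)  # one discrete ID per codebook level
    # total_loss = recon_loss + sum over levels of (codebook + beta * commitment)
    assert torch.isclose(total_loss, recon_loss + rq_loss)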