torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/models/generative/hllm.py (new file)
@@ -0,0 +1,249 @@
+"""HLLM: Hierarchical Large Language Model for Recommendation."""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch_rechub.utils.hstu_utils import RelPosBias
+
+
+class HLLMTransformerBlock(nn.Module):
+    """Single HLLM Transformer block with self-attention and FFN.
+
+    This block is similar to HSTULayer but designed for HLLM which uses
+    pre-computed item embeddings as input instead of learnable token embeddings.
+
+    Args:
+        d_model (int): Hidden dimension.
+        n_heads (int): Number of attention heads.
+        dropout (float): Dropout rate.
+    """
+
+    def __init__(self, d_model=512, n_heads=8, dropout=0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+
+        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+        self.head_dim = d_model // n_heads
+        self.scale = self.head_dim**-0.5
+
+        # Multi-head self-attention
+        self.W_Q = nn.Linear(d_model, d_model)
+        self.W_K = nn.Linear(d_model, d_model)
+        self.W_V = nn.Linear(d_model, d_model)
+        self.W_O = nn.Linear(d_model, d_model)
+
+        # Feed-forward network
+        ffn_hidden = 4 * d_model
+        self.ffn = nn.Sequential(nn.Linear(d_model, ffn_hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ffn_hidden, d_model), nn.Dropout(dropout))
+
+        # Layer normalization and dropout
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, rel_pos_bias=None):
+        """Forward pass.
+
+        Args:
+            x (Tensor): Input of shape (B, L, D).
+            rel_pos_bias (Tensor, optional): Relative position bias.
+
+        Returns:
+            Tensor: Output of shape (B, L, D).
+        """
+        batch_size, seq_len, _ = x.shape
+
+        # Self-attention with residual
+        residual = x
+        x = self.norm1(x)
+
+        Q = self.W_Q(x)  # (B, L, D)
+        K = self.W_K(x)
+        V = self.W_V(x)
+
+        # Reshape for multi-head attention
+        Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+        # Attention scores
+        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
+
+        # Causal mask
+        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
+        scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+        # Add relative position bias if provided
+        if rel_pos_bias is not None:
+            scores = scores + rel_pos_bias
+
+        attn_weights = F.softmax(scores, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        attn_output = torch.matmul(attn_weights, V)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
+        attn_output = self.W_O(attn_output)
+        attn_output = self.dropout(attn_output)
+
+        x = residual + attn_output
+
+        # FFN with residual
+        residual = x
+        x = self.norm2(x)
+        x = self.ffn(x)
+        x = residual + x
+
+        return x
+
+
+class HLLMModel(nn.Module):
+    """HLLM: Hierarchical Large Language Model for Recommendation.
+
+    This is a lightweight implementation of HLLM that uses pre-computed item
+    embeddings as input. The original ByteDance HLLM uses end-to-end training
+    with both Item LLM and User LLM, but this implementation focuses on the
+    User LLM component for resource efficiency.
+
+    Architecture:
+        - Item Embeddings: Pre-computed using LLM (offline, frozen)
+          Format: "{item_prompt}title: {title}description: {description}"
+          where item_prompt = "Compress the following sentence into embedding: "
+        - User LLM: Transformer blocks that model user sequences (trainable)
+        - Scoring Head: Dot product between user representation and item embeddings
+
+    Reference:
+        ByteDance HLLM: https://github.com/bytedance/HLLM
+
+    Args:
+        item_embeddings (Tensor or str): Pre-computed item embeddings of shape
+            (vocab_size, d_model), or path to a .pt file containing embeddings.
+            Generated using the last token's hidden state from an LLM.
+        vocab_size (int): Vocabulary size (number of items).
+        d_model (int): Hidden dimension. Should match item embedding dimension.
+            Default: 512. TinyLlama uses 2048, Baichuan2 uses 4096.
+        n_heads (int): Number of attention heads. Default: 8.
+        n_layers (int): Number of transformer blocks. Default: 4.
+        max_seq_len (int): Maximum sequence length. Default: 256.
+            Official uses MAX_ITEM_LIST_LENGTH=50.
+        dropout (float): Dropout rate. Default: 0.1.
+        use_rel_pos_bias (bool): Whether to use relative position bias. Default: True.
+        use_time_embedding (bool): Whether to use time embeddings. Default: True.
+        num_time_buckets (int): Number of time buckets. Default: 2048.
+        time_bucket_fn (str): Time bucketization function ('sqrt' or 'log'). Default: 'sqrt'.
+        temperature (float): Temperature for NCE scoring. Default: 1.0.
+            Official uses logit_scale = log(1/0.07) ≈ 2.66.
+    """
+
+    def __init__(self, item_embeddings, vocab_size, d_model=512, n_heads=8, n_layers=4, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt', temperature=1.0):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.max_seq_len = max_seq_len
+        self.use_time_embedding = use_time_embedding
+        self.num_time_buckets = num_time_buckets
+        self.time_bucket_fn = time_bucket_fn
+        self.temperature = temperature
+
+        # Load item embeddings
+        if isinstance(item_embeddings, str):
+            item_embeddings = torch.load(item_embeddings)
+
+        # Register as buffer (not trainable)
+        self.register_buffer('item_embeddings', item_embeddings.float())
+
+        # Positional embedding
+        self.position_embedding = nn.Embedding(max_seq_len, d_model)
+
+        # Time embedding
+        if use_time_embedding:
+            self.time_embedding = nn.Embedding(num_time_buckets + 1, d_model, padding_idx=0)
+
+        # Transformer blocks
+        self.transformer_blocks = nn.ModuleList([HLLMTransformerBlock(d_model=d_model, n_heads=n_heads, dropout=dropout) for _ in range(n_layers)])
+
+        # Relative position bias
+        self.use_rel_pos_bias = use_rel_pos_bias
+        if use_rel_pos_bias:
+            self.rel_pos_bias = RelPosBias(n_heads, max_seq_len)
+
+        # Dropout
+        self.dropout = nn.Dropout(dropout)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize model parameters."""
+        for name, param in self.named_parameters():
+            if 'weight' in name and len(param.shape) > 1:
+                nn.init.xavier_uniform_(param)
+            elif 'bias' in name:
+                nn.init.constant_(param, 0)
+
+    def _time_diff_to_bucket(self, time_diffs):
+        """Map time differences to bucket indices."""
+        time_diffs = time_diffs.float() / 60.0  # seconds to minutes
+        time_diffs = torch.clamp(time_diffs, min=1e-6)
+
+        if self.time_bucket_fn == 'sqrt':
+            buckets = torch.sqrt(time_diffs).long()
+        elif self.time_bucket_fn == 'log':
+            buckets = torch.log(time_diffs).long()
+        else:
+            raise ValueError(f"Unsupported time_bucket_fn: {self.time_bucket_fn}")
+
+        buckets = torch.clamp(buckets, min=0, max=self.num_time_buckets - 1)
+        return buckets
+
+    def forward(self, seq_tokens, time_diffs=None):
+        """Forward pass.
+
+        Args:
+            seq_tokens (Tensor): Item token IDs of shape (B, L).
+            time_diffs (Tensor, optional): Time differences in seconds of shape (B, L).
+
+        Returns:
+            Tensor: Logits of shape (B, L, vocab_size).
+        """
+        batch_size, seq_len = seq_tokens.shape
+
+        # Look up item embeddings
+        item_emb = self.item_embeddings[seq_tokens]  # (B, L, D)
+
+        # Add positional embedding
+        positions = torch.arange(seq_len, dtype=torch.long, device=seq_tokens.device)
+        pos_emb = self.position_embedding(positions)  # (L, D)
+        embeddings = item_emb + pos_emb.unsqueeze(0)  # (B, L, D)
+
+        # Add time embedding if provided
+        if self.use_time_embedding:
+            if time_diffs is None:
+                time_diffs = torch.zeros(batch_size, seq_len, dtype=torch.long, device=seq_tokens.device)
+
+            time_buckets = self._time_diff_to_bucket(time_diffs)
+            time_emb = self.time_embedding(time_buckets)  # (B, L, D)
+            embeddings = embeddings + time_emb
+
+        embeddings = self.dropout(embeddings)
+
+        # Get relative position bias
+        rel_pos_bias = None
+        if self.use_rel_pos_bias:
+            rel_pos_bias = self.rel_pos_bias(seq_len)
+
+        # Pass through transformer blocks
+        x = embeddings
+        for block in self.transformer_blocks:
+            x = block(x, rel_pos_bias=rel_pos_bias)
+
+        # Scoring head: compute dot product with item embeddings
+        # x: (B, L, D), item_embeddings: (V, D)
+        logits = torch.matmul(x, self.item_embeddings.t()) / self.temperature  # (B, L, V)
+
+        return logits
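
For orientation, here is a minimal usage sketch of the new HLLMModel. It is an illustration only: the import path follows the new file torch_rechub/models/generative/hllm.py added in this release, and the random tensor stands in for item embeddings that would normally be pre-computed offline from an LLM's last-token hidden states, as described in the docstring above.

    import torch
    from torch_rechub.models.generative.hllm import HLLMModel  # illustrative import; path per this release

    vocab_size, d_model = 1000, 512
    item_embeddings = torch.randn(vocab_size, d_model)  # stand-in for frozen, LLM-derived embeddings

    model = HLLMModel(item_embeddings=item_embeddings, vocab_size=vocab_size,
                      d_model=d_model, n_layers=2, max_seq_len=50)

    seq_tokens = torch.randint(0, vocab_size, (4, 50))  # (B, L) item ids
    time_diffs = torch.randint(0, 86400, (4, 50))       # seconds between consecutive interactions
    logits = model(seq_tokens, time_diffs)              # (4, 50, 1000)
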
torch_rechub/models/generative/hstu.py (new file)
@@ -0,0 +1,189 @@
+"""HSTU: Hierarchical Sequential Transduction Units Model."""
+
+import math
+
+import torch
+import torch.nn as nn
+
+from torch_rechub.basic.layers import HSTUBlock
+from torch_rechub.utils.hstu_utils import RelPosBias
+
+
+class HSTUModel(nn.Module):
+    """HSTU: Hierarchical Sequential Transduction Units model.
+
+    Autoregressive generative recommendation model for sequential data.
+    This module stacks multiple ``HSTUBlock`` layers to capture long-range
+    dependencies in user interaction sequences and predicts the next item.
+
+    Args:
+        vocab_size (int): Vocabulary size (number of distinct items, including PAD).
+        d_model (int): Hidden dimension of the model. Default: 512.
+        n_heads (int): Number of attention heads. Default: 8.
+        n_layers (int): Number of stacked HSTU layers. Default: 4.
+        dqk (int): Dimension of query/key vectors per head. Default: 64.
+        dv (int): Dimension of value vectors per head. Default: 64.
+        max_seq_len (int): Maximum sequence length. Default: 256.
+        dropout (float): Dropout rate applied in the model. Default: 0.1.
+        use_rel_pos_bias (bool): Whether to use relative position bias. Default: True.
+        use_time_embedding (bool): Whether to use time-difference embeddings. Default: True.
+        num_time_buckets (int): Number of time buckets for time embeddings. Default: 2048.
+        time_bucket_fn (str): Function used to bucketize time differences, ``"sqrt"``
+            or ``"log"``. Default: ``"sqrt"``.
+
+    Shape:
+        - Input: ``x`` of shape ``(batch_size, seq_len)``; optional ``time_diffs``
+          of shape ``(batch_size, seq_len)`` representing time differences in seconds.
+        - Output: Logits of shape ``(batch_size, seq_len, vocab_size)``.
+
+    Example:
+        >>> model = HSTUModel(vocab_size=100000, d_model=512)
+        >>> x = torch.randint(0, 100000, (32, 256))
+        >>> time_diffs = torch.randint(0, 86400, (32, 256))
+        >>> logits = model(x, time_diffs)
+        >>> logits.shape
+        torch.Size([32, 256, 100000])
+    """
+
+    def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt'):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.max_seq_len = max_seq_len
+        self.use_time_embedding = use_time_embedding
+        self.num_time_buckets = num_time_buckets
+        self.time_bucket_fn = time_bucket_fn
+
+        # Alpha scaling factor (following the Meta reference implementation)
+        self.alpha = math.sqrt(d_model)
+
+        # Token embedding
+        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
+
+        # Absolute positional embedding
+        self.position_embedding = nn.Embedding(max_seq_len, d_model)
+
+        # Time embedding
+        if use_time_embedding:
+            # Embedding table for time-difference buckets
+            # num_time_buckets + 1: extra bucket reserved for padding
+            self.time_embedding = nn.Embedding(num_time_buckets + 1, d_model, padding_idx=0)
+
+        # HSTU block
+        self.hstu_block = HSTUBlock(d_model=d_model, n_heads=n_heads, n_layers=n_layers, dqk=dqk, dv=dv, dropout=dropout, use_rel_pos_bias=use_rel_pos_bias)
+
+        # Relative position bias
+        self.use_rel_pos_bias = use_rel_pos_bias
+        if use_rel_pos_bias:
+            self.rel_pos_bias = RelPosBias(n_heads, max_seq_len)
+
+        # Output projection layer
+        self.output_projection = nn.Linear(d_model, vocab_size)
+
+        # Dropout
+        self.dropout = nn.Dropout(dropout)
+
+        # Initialize model parameters
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize model parameters."""
+        for name, param in self.named_parameters():
+            if 'weight' in name and len(param.shape) > 1:
+                nn.init.xavier_uniform_(param)
+            elif 'bias' in name:
+                nn.init.constant_(param, 0)
+
+    def _time_diff_to_bucket(self, time_diffs):
+        """Map raw time differences (in seconds) to discrete bucket indices.
+
+        Following the Meta HSTU implementation, continuous time differences are
+        first converted to minutes and then bucketized using either a square-root
+        or logarithmic transform.
+
+        Args:
+            time_diffs (Tensor): Time differences in seconds,
+                shape ``(batch_size, seq_len)``.
+
+        Returns:
+            Tensor: Integer bucket indices of shape ``(batch_size, seq_len)``.
+        """
+        # Convert seconds to minutes (as in the Meta reference implementation)
+        time_bucket_increments = 60.0
+        time_diffs = time_diffs.float() / time_bucket_increments
+
+        # Ensure non-negative values and avoid log(0)
+        time_diffs = torch.clamp(time_diffs, min=1e-6)
+
+        if self.time_bucket_fn == 'sqrt':
+            # Use the square-root transform: suitable when time differences
+            # are relatively evenly distributed.
+            buckets = torch.sqrt(time_diffs).long()
+        elif self.time_bucket_fn == 'log':
+            # Use the logarithmic transform: suitable when time differences
+            # span several orders of magnitude.
+            buckets = torch.log(time_diffs).long()
+        else:
+            raise ValueError(f"Unsupported time_bucket_fn: {self.time_bucket_fn}")
+
+        # Clamp bucket indices to the valid range [0, num_time_buckets - 1]
+        buckets = torch.clamp(buckets, min=0, max=self.num_time_buckets - 1)
+
+        return buckets
+
+    def forward(self, x, time_diffs=None):
+        """Forward pass.
+
+        Args:
+            x (Tensor): Input token ids of shape ``(batch_size, seq_len)``.
+            time_diffs (Tensor, optional): Time differences in seconds,
+                shape ``(batch_size, seq_len)``. If ``None`` and
+                ``use_time_embedding=True``, all-zero time differences are used.
+
+        Returns:
+            Tensor: Logits over the vocabulary of shape
+                ``(batch_size, seq_len, vocab_size)``.
+        """
+        batch_size, seq_len = x.shape
+
+        # Token embedding with alpha scaling (as in the Meta implementation)
+        token_emb = self.token_embedding(x) * self.alpha  # (B, L, D)
+
+        # Absolute positional embedding
+        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
+        pos_emb = self.position_embedding(positions)  # (L, D)
+
+        # Combine token and position embeddings
+        embeddings = token_emb + pos_emb.unsqueeze(0)  # (B, L, D)
+
+        # Optional time-difference embedding
+        if self.use_time_embedding:
+            if time_diffs is None:
+                # Fallback: use all-zero time differences when none are provided
+                time_diffs = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
+
+            # Map raw time differences to bucket indices
+            time_buckets = self._time_diff_to_bucket(time_diffs)  # (B, L)
+
+            # Look up time embeddings and add to the sequence representation
+            time_emb = self.time_embedding(time_buckets)  # (B, L, D)
+
+            # embeddings = token_emb + pos_emb + time_emb
+            embeddings = embeddings + time_emb
+
+        embeddings = self.dropout(embeddings)
+
+        # Relative position bias for self-attention
+        rel_pos_bias = None
+        if self.use_rel_pos_bias:
+            rel_pos_bias = self.rel_pos_bias(seq_len)  # (1, H, L, L)
+
+        # HSTU block
+        hstu_output = self.hstu_block(embeddings, rel_pos_bias=rel_pos_bias)  # (B, L, D)
+
+        # Final projection to vocabulary logits
+        logits = self.output_projection(hstu_output)  # (B, L, V)
+
+        return logits
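
As a hedged sketch of how the HSTU logits would typically be consumed (this release also adds trainers/seq_trainer.py, whose actual API may differ), next-item training can use a shifted cross-entropy that ignores the PAD id 0:

    import torch
    import torch.nn.functional as F
    from torch_rechub.models.generative.hstu import HSTUModel  # illustrative import; path per this release

    model = HSTUModel(vocab_size=1000)      # defaults: d_model=512, n_heads=8, n_layers=4
    seq = torch.randint(1, 1000, (8, 50))   # item ids; 0 is reserved for PAD
    logits = model(seq)                     # (8, 50, 1000)

    # Position t predicts the item at t+1, so shift targets left by one step.
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, 1000), seq[:, 1:].reshape(-1), ignore_index=0)
    loss.backward()
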
torch_rechub/models/matching/__init__.py
@@ -1,11 +1,13 @@
-
-
-from .
-from .
-from .
-from .
-from .mind import MIND
-from .narm import NARM
-from .
-from .
-from .
+__all__ = ['DSSM', 'FaceBookDSSM', 'YoutubeDNN', 'YoutubeSBC', 'MIND', 'GRU4Rec', 'NARM', 'SASRec', 'SINE', 'STAMP', 'ComirecDR', 'ComirecSA']
+
+from .comirec import ComirecDR, ComirecSA
+from .dssm import DSSM
+from .dssm_facebook import FaceBookDSSM
+from .gru4rec import GRU4Rec
+from .mind import MIND
+from .narm import NARM
+from .sasrec import SASRec
+from .sine import SINE
+from .stamp import STAMP
+from .youtube_dnn import YoutubeDNN
+from .youtube_sbc import YoutubeSBC
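
With the explicit __all__ above, the matching models listed there remain importable directly from the subpackage, for example:

    from torch_rechub.models.matching import SASRec, ComirecSA, YoutubeDNN
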